PaddleOCR-VL 作为一款开源多模态模型,凭借其强大的文本检测、识别与图文关联分析能力,在工业质检、智能文档审核、车载多模态OCR 等场景中已取得广泛应用。然而,PaddleOCR-VL 与昇腾服务器存在兼容性问题,导致其无法在昇腾服务器上直接部署与运行。为此,本文重点完成 PaddleOCR-VL 在昇腾A3平台的适配工作,实现模型在昇腾服务器上的稳定部署与推理,全面支持文本检测、文字识别及图文语义关联等核心多模态功能。同时,通过容器化封装,构建可一键启动、跨环境复用的标准化部署方案,显著提升模型在实际业务场景中的可落地性。
| 组件 | 版本 |
|---|---|
| Atlas 900 A3 | 910C |
| Ascend HDK | 25.2.3 |
| 操作系统 | openEuler 22.03 (LTS-SP4) |
| CANN | 8.3.RC1 |
PaddleOCR-VL 由两个模型组成:版面分析模型 PP-DocLayoutV2 和视觉语言模型(VLM)PaddleOCR-VL-0.9B,其中核心是 PaddleOCR-VL-0.9B。
进一步地,PP-DocLayoutV2 在基础版面检测模型 PP-DocLayout_plus-L 之后级联了一个包含 6 层 Transformer 的指针网络;PaddleOCR-VL-0.9B 则由一个动态分辨率视觉编码器与 ERNIE-4.5-0.3B 语言模型构成。
权重下载:PaddleOCR-VL · 模型库
vLLM 版本:0.11.1(源码获取方式如下):
# Fetch vLLM and pin it to v0.11.1. The checkout must run inside the clone.
git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout v0.11.1
cd ..

# Fetch vllm-ascend and pin it to the tested commit. The reset must run
# inside the clone.
git clone https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
git reset --hard 6664a4e
cd ..

# NOTE(review): the commands above only fetch sources; vllm and vllm-ascend
# still need to be installed from these checkouts (see the vllm-ascend
# installation guide) before `vllm serve` is available.

# PaddleOCR pipeline dependencies (same pip invocation style throughout).
python -m pip install paddlepaddle==3.2.0
python -m pip install -U "paddleocr[doc-parser]"
python -m pip install safetensors
运行时可能缺少 OpenCV 依赖的系统库(libGL、GLib),在 Debian/Ubuntu 系统上可执行以下命令安装:
# Debian/Ubuntu: system libraries required by OpenCV (libGL, GLib).
apt-get install --yes libgl1 libglib2.0-0
注意:不同操作系统的安装命令不同,请根据当前操作系统选择对应命令。openEuler 操作系统对应的命令如下:
# openEuler: system libraries required by OpenCV. Each install is its own
# command; fused onto one line, yum would also try to install packages named
# "yum" and "install".
yum install -y mesa-libGL
yum install -y glib2
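在 Ernie4.5 语言模型文件(vllm/model_executor/models/ernie45.py,附录中的 PaddleOCR-VL 实现从该文件导入 Ernie4_5ForCausalLM)第26行添加以下导入: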
from vllm.compilation.decorators import support_torch_compile
在第32行添加以下装饰器(装饰器须紧邻被装饰的类定义之前):
# Decorator added in front of the language-model class definition so that it
# is compiled with torch.compile.
@support_torch_compile(
    # Maps each forward-argument name to the dimension treated as dynamic
    # (varying between calls) during compilation.
    # "positions" uses -1 (the last dim) instead of 0 -- presumably because
    # M-RoPE positions carry an extra leading axis; TODO confirm.
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
在模型注册文件(vllm/model_executor/models/registry.py)中 "Ovis2_5" 条目的下一行,增加以下注册代码:
"PaddleOCRVLForConditionalGeneration": (
"paddleocr_vl",
"PaddleOCRVLForConditionalGeneration",
),
模型实现文件 paddleocr_vl.py 的完整内容见附录,需将其放置在 vllm/model_executor/models/ 目录下(与上面注册代码中的模块名 "paddleocr_vl" 对应)。
# Start the OpenAI-compatible vLLM server for PaddleOCR-VL-0.9B.
# Trailing backslashes keep this a single command; without them each flag
# line would run as a separate (failing) shell command.
vllm serve PaddleOCR-VL-0.9B \
    --trust-remote-code \
    --max-num-batched-tokens 16384 \
    --no-enable-prefix-caching \
    --mm-processor-cache-gb 0
from paddleocr import PaddleOCRVL

# Layout detection uses the local PP-DocLayoutV2 weights; text recognition is
# delegated to the vLLM server at ``vl_rec_server_url``.
pipeline = PaddleOCRVL(
    layout_detection_model_dir="[YOUR_PATH]/PaddleOCR-VL-0.9B/PP-DocLayoutV2",
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    vl_rec_backend="vllm-server",
    vl_rec_server_url="http://127.0.0.1:8000/v1",
)

results = pipeline.predict("./image_path")
for result in results:
    # Print each result, then write JSON and Markdown into ./output.
    result.print()
    result.save_to_json(save_path="output")
    result.save_to_markdown(save_path="output")

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
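# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.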
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
)
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
dispatch_rotary_emb_function,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargs,
)
from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .ernie45 import Ernie4_5ForCausalLM
from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
is_pp_missing_parameter,
maybe_prefix,
)
from .vision import get_vit_attn_backend
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Compute a resized ``(height, width)`` suitable for the vision encoder.

    The result aims to satisfy:

    1. Both sides are multiples of ``factor``.
    2. The pixel count lies within ``[min_pixels, max_pixels]``.
    3. The original aspect ratio is approximately preserved.

    A side shorter than ``factor`` is first scaled up to ``factor``; the other
    side is scaled proportionally.

    Raises:
        ValueError: if the long/short side ratio exceeds 200.
    """
    if height < factor:
        width = round((width * factor) / height)
        height = factor
    if width < factor:
        height = round((height * factor) / width)
        width = factor
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"图像的宽高比绝对值必须小于200,当前为{max(height, width) / min(height, width)}"
        )
    # Snap both sides to the nearest multiple of ``factor``.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too large: shrink uniformly, rounding down so the cap is respected.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too small: enlarge uniformly, rounding up so the floor is reached.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs by 90 degrees, as used by rotary embeddings.

    Non-interleaved: the last dim is split into halves ``(x1, x2)`` and
    ``(-x2, x1)`` is returned. Interleaved: even/odd positions form the pairs
    and the result is re-interleaved.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Pure-PyTorch rotary embedding.

    Only the first ``2 * cos.shape[-1]`` features of ``x`` are rotated; the rest
    pass through unchanged.

    Args:
        x: ``(batch_size, seqlen, nheads, headdim)``.
        cos, sin: ``(seqlen, rotary_dim / 2)`` or
            ``(batch_size, seqlen, rotary_dim / 2)``.
        interleaved: pair features as even/odd positions instead of halves.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Expand to rotary_dim and insert a broadcast axis for the heads.
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary angles ``freqs`` to ``t`` in float32 and cast back.

    The kernel comes from ``dispatch_rotary_emb_function``, with
    ``apply_rotary_emb_torch`` passed as its default.
    """
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output


class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata: HF config/processor access and image token counts."""

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(**kwargs)

    def get_image_processor(self, **kwargs: object):
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self):
        # No fixed cap on the number of images per prompt.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        """Number of placeholder tokens an image of the given size expands to.

        The image is resized with ``smart_resize`` (``factor`` = patch size times
        spatial merge size). It is then split into ``patch_size`` patches, and
        every ``merge_size ** 2`` patches become one token.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()
        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        resized_height, resized_width = smart_resize(
            height=image_height,
            width=image_width,
            factor=patch_size * merge_size,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
        )
        # Still images have a single temporal step (grid_t == 1).
        grid_h = resized_height // patch_size
        grid_w = resized_width // patch_size
        return (grid_h * grid_w) // (merge_size**2)

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE(review): this returns the vision config's nominal image_size,
        # which may yield fewer tokens than ``max_pixels`` allows; confirm
        # profiling does not underestimate peak memory.
        hf_config = self.get_hf_config()
        image_size = hf_config.vision_config.image_size
        return ImageSize(height=image_size, width=image_size)


class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]):
    """Builds dummy text/images used for memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        max_image_size = self.info.get_image_size_with_most_features()
        image_overrides = mm_options.get("image") if mm_options else None
        return {
            "image": self._get_dummy_images(
                width=max_image_size.width,
                height=max_image_size.height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class PaddleOCRVLMultiModalProcessor(
    BaseMultiModalProcessor[PaddleOCRVLProcessingInfo]
):
    """Runs the HF processor and expands image placeholders into image tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            processed_outputs = self.info.ctx.call_hf_processor(
                self.info.get_hf_processor(**mm_kwargs),
                dict(text=prompt, **mm_data),
                dict(**mm_kwargs, **tok_kwargs),
            )
            # Split the flat patch tensor into one chunk per image
            # (t * h * w patches each).
            num_patches_per_image = processed_outputs["image_grid_thw"].prod(-1)
            processed_outputs["pixel_values"] = processed_outputs["pixel_values"].split(
                num_patches_per_image.tolist()
            )
        else:
            # Text-only prompt: tokenize directly.
            tokenizer = self.info.get_tokenizer()
            processed_outputs = tokenizer(
                prompt, add_special_tokens=True, return_tensors="pt"
            )
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_id

        def get_replacement(item_idx: int, image_processor):
            # Replace one image token with as many tokens as the image yields.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            num_image_tokens = self.info.get_num_image_tokens(
                image_width=image_size.width,
                image_height=image_size.height,
                image_processor=image_processor,
            )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=partial(get_replacement, image_processor=image_processor),
            ),
        ]


class Projector(nn.Module):
    """Merges 2x2 vision patches and projects them to the text hidden size.

    Pipeline: LayerNorm -> 2x2 spatial merge -> Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)
        # Width of one merged token: vision hidden size times the merge window area.
        self.hidden_size = (
            self.vision_config.hidden_size
            * self.merge_kernel_size[0]
            * self.merge_kernel_size[1]
        )
        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(
            self.hidden_size, self.text_config.hidden_size, bias=True
        )

    def forward(
        self,
        image_features: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Project features of one or more images.

        For a list/tuple, each element holds the ``t*h*w`` patch features of one
        image and is merged using its ``(t, h, w)`` grid; a list is returned.
        For a tensor, consecutive groups of ``m1*m2`` rows are merged (assumed
        already ordered by merge window; TODO confirm).
        """
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = list()
            for image_feature, image_grid in zip(image_features, image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                t, h, w = image_grid
                # Gather each m1 x m2 spatial window into one feature vector.
                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                hidden_states = self.linear_1(image_feature)
                hidden_states = self.act(hidden_states)
                hidden_states = self.linear_2(hidden_states)
                processed_features.append(hidden_states)
            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        image_features = image_features.view(np.prod(dims), dim)
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states.view(*dims, -1)


class PaddleOCRImagePixelInputs(TensorSchema):
    """Schema of the pixel inputs handed to the vision tower."""

    type: Literal["pixel_values"]
    # Per image: a variable number ``p`` of 3-channel patch_size x patch_size patches.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", "p", 3, "patch_size", "patch_size", dynamic_dims={"p"}),
    ]
    # Per image: its (t, h, w) patch grid.
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("bn", 3),
    ]


class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding plus (interpolated or packed) position embeddings."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        # State for fetch_position_embedding_lfu_cache: (h, w) -> embedding / hit count.
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:
        """Bilinearly resize the square position table to a (h, w) patch grid.

        ``height``/``width`` are patch counts when ``is_after_patchify`` is True,
        otherwise pixels (divided by ``patch_size``).

        Returns:
            Tensor of shape ``(1, new_height * new_width, dim)``.
        """
        num_positions = self.position_embedding.weight.shape[0]
        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
        dim = embeddings.shape[-1]
        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def fetch_position_embedding_lfu_cache(
        self, embeddings: torch.Tensor, h: int, w: int, max_cache: int = 20
    ):
        """Return the interpolated position embedding for grid (h, w), cached.

        At most ``max_cache`` grids are kept; when full, the grid with the
        fewest hits is evicted before inserting a new one.
        """
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]
        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)
        position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        """Embed patches and add position embeddings.

        ``pixel_values`` must be 4-D (a batch axis is added) or 5-D
        ``(b, l, c, h, w)``. With ``interpolate_pos_encoding`` and a grid list,
        per-image interpolated embeddings are added; otherwise
        ``packing_position_embedding(position_ids)`` is added.

        Raises:
            ValueError: if ``position_ids`` is None, or if ``pixel_values`` is
                neither 4-D nor 5-D.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() != 5:
            raise ValueError(
                "不支持的pixel_values维度:"
                f"{pixel_values.dim()}。预期维度为4或5。"
            )
        if position_ids is None:
            raise ValueError(
                "当pixel_values的维度为5时,position_ids不能为None。"
            )
        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        embeddings = patch_embeds.flatten(-2).squeeze(-1)
        if interpolate_pos_encoding and image_grid_thw is not None:
            start = 0
            tmp_embeddings = list()
            for image_grid in image_grid_thw:
                t, h, w = image_grid
                end = start + t * h * w
                image_embeddings = embeddings[start:end, :]
                # Same spatial embedding repeated for every temporal step.
                position_embedding = (
                    self.interpolate_pos_encoding(image_embeddings, h, w, True)
                    .squeeze(0)
                    .repeat(t, 1)
                )
                image_embeddings = image_embeddings + position_embedding
                tmp_embeddings.append(image_embeddings)
                start = end
            embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
        else:
            embeddings = embeddings + self.packing_position_embedding(position_ids)
        return embeddings


def all_gather_interleave(local_tensor: torch.Tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` across the TP group and interleave shards.

    Each rank's last dim is split into chunks of ``hidden_size // tp_size``.
    Chunk i of every rank is placed next to chunk i of the others, and the
    result is concatenated along the last dim.
    """
    import torch.distributed as dist

    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
    )
    gathered_tensors_split = [
        torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
    ]
    ordered_tensors = [
        tensor for pair in zip(*gathered_tensors_split) for tensor in pair
    ]
    return torch.cat(ordered_tensors, dim=-1)


class SiglipAttention(nn.Module):
    """SigLIP vision self-attention adapted from Qwen2.5-VisionAttention.

    Supports flash-attention (incl. ROCm AITER), torch SDPA and xFormers
    backends; attention is restricted to each sequence given by ``cu_seqlens``.
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        attn_backend_override: _Backend | None = None,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split ``(seq_len, bs, 3 * local_hidden)`` QKV into per-head q, k, v.

        With TP > 1 the shards are first all-gathered and interleaved, then
        each of q/k/v is re-split so this rank keeps its own heads.
        """
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = all_gather_interleave(qkv, self.qkv_proj.hidden_size, self.tp_size)
        q, k, v = qkv.chunk(3, dim=2)
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None,
        max_seqlen: torch.Tensor | None,
        seqlens: torch.Tensor | None,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` of shape ``(b, s, d)``; returns ``(b, s, d)``."""
        batch_size, _, _ = hidden_states.shape
        x = rearrange(hidden_states, "b s d -> s b d")
        x, _ = self.qkv_proj(x)
        q, k, v = self.split_qkv(x)
        q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k in a single call by stacking them on the batch axis.
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if max_seqlen is None:
                raise ValueError("Flash attention后端需要max_seqlen参数。")
            context_layer = vit_flash_attn_wrapper(
                q,
                k,
                v,
                cu_seqlens,
                max_seqlen,
                batch_size,
                self.attn_backend == _Backend.ROCM_AITER_FA,
                self.use_upstream_fa,
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA has no varlen API: attend within each cu_seqlens segment.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(tensor, "b s h d -> b h s d")
                    for tensor in (q_i, k_i, v_i)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            if seqlens is None:
                raise ValueError("xFormers 注意力后端需要 seqlens 张量。")
            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
        else:
            raise RuntimeError(
                f"PaddleOCR-VL 目前不支持 {self.attn_backend} 后端。"
            )

        output, _ = self.out_proj(context_layer)
        output = rearrange(output, "s b d -> b s d")
        return output
class SigLIPRotaryEmbedding(nn.Module):
    """Rotary angle table for the SigLIP vision encoder.

    ``forward(seqlen)`` returns ``outer(arange(seqlen), inv_freq)``, a tensor of
    shape ``(seqlen, ceil(dim / 2))``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        # Must be ``__init__``: as ``init``, construction would forward ``dim``
        # to ``nn.Module.__init__`` and raise TypeError.
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        """Compute ``inv_freq`` and register it as a non-persistent buffer."""
        inv_freq = 1.0 / (
            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
        )
        # Non-persistent: recomputed on construction, not stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return rotation angles for positions ``0 .. seqlen - 1``."""
        seq = torch.arange(
            seqlen,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        return torch.outer(seq, self.inv_freq)
class SiglipMLP(nn.Module):
    """Feed-forward block: fc1 (column-parallel) -> activation -> fc2 (row-parallel)."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        # Must be ``__init__`` so that nn.Module is initialised on construction.
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # Quantize only when the method handles any shape (bitsandbytes/torchao),
        # or when both widths are multiples of 64.
        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
            quantizable = True
        else:
            quantizable = (
                config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
            )
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: LN -> attention -> residual, LN -> MLP -> residual."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        attn_backend_override: _Backend | None = None,
        use_upstream_fa: bool = False,
    ):
        # Must be ``__init__`` so that nn.Module is initialised on construction.
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = SiglipAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            projection_size=config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_backend=attn_backend,
            attn_backend_override=attn_backend_override,
            use_upstream_fa=use_upstream_fa,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None,
        max_seqlen: torch.Tensor | None,
        seqlens: torch.Tensor | None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
class SiglipEncoder(nn.Module):
    """Stack of SiglipEncoderLayers with 2-D (row, column) rotary positions."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ):
        # Must be ``__init__`` so that nn.Module is initialised on construction.
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        # Prefer upstream flash-attention when the selected backend is not FA
        # but upstream FA is available for the default dtype.
        self.use_upstream_fa = False
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        } and check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"PaddleOCR-VL 目前不支持 {self.attn_backend} 后端。"
            )
        self.layers = nn.ModuleList(
            [
                SiglipEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    attn_backend=self.attn_backend,
                    attn_backend_override=attn_backend_override,
                    use_upstream_fa=self.use_upstream_fa,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # head_dim // 2 angles per axis; row and column angles are concatenated.
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level of nested grid lists into a flat list of grids."""
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def forward(
        self,
        inputs_embeds,
        cu_seqlens: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run all encoder layers.

        Missing row/column position ids are derived from ``image_grid_thw``;
        every temporal step reuses the same ``h x w`` row-major grid.

        Raises:
            ValueError: if ``cu_seqlens`` is None.
        """
        device = inputs_embeds.device
        flatten_image_grid_thw = self.flatten_list(image_grid_thw)
        if width_position_ids is None or height_position_ids is None:
            split_hids = list()
            split_wids = list()
            for t, h, w in flatten_image_grid_thw:
                image_pids = torch.arange(t * h * w, device=device) % (h * w)
                split_hids.append(image_pids // w)
                split_wids.append(image_pids % w)
            width_position_ids = torch.concat(split_wids, dim=0)
            height_position_ids = torch.concat(split_hids, dim=0)

        # Look up row and column angles and concatenate them per token.
        pids = torch.stack(
            [height_position_ids, width_position_ids],
            dim=-1,
        )
        max_grid_size = pids.max() + 1
        rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rope_emb_max_grid[pids].flatten(1)

        if cu_seqlens is None:
            raise ValueError("SiglipEncoder 的 cu_seqlens 不能为 None。")
        if not isinstance(cu_seqlens, torch.Tensor):
            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
        else:
            cu_seqlens = cu_seqlens.to(device=device)

        # Only the backend that needs it gets max_seqlen / seqlens.
        max_seqlen = None
        seqlens = None
        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )
        return hidden_states
class SiglipVisionTransformer(nn.Module):
    """Embeddings -> SiglipEncoder -> final LayerNorm."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ):
        # Must be ``__init__`` so that nn.Module is initialised on construction.
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_backend_override=attn_backend_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool | None = False,
        position_ids: torch.Tensor | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        image_grid_thw: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the layer-normalised last hidden state."""
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )
        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
        )
        return self.post_layernorm(last_hidden_state)
class SiglipVisionModel(nn.Module):
    """Wrapper around SiglipVisionTransformer exposing dtype/device and weight loading."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ):
        # Must be ``__init__`` so that nn.Module is initialised on construction.
        super().__init__()
        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            attn_backend_override=attn_backend_override,
        )
        self.quant_config = quant_config

    @property
    def dtype(self) -> torch.dtype:
        return self.vision_model.embeddings.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.vision_model.embeddings.patch_embedding.weight.device

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        interpolate_pos_encoding: bool = False,
        position_ids: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the vision transformer's last hidden state (a plain tensor)."""
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
            cu_seqlens=cu_seqlens,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing separate q/k/v into ``qkv_proj``.

        Rotary ``inv_freq`` buffers and SigLIP pooling-head tensors (``head.*``),
        for which this module has no parameters, are skipped.

        Returns:
            Names of the parameters that received a weight.
        """
        stacked_params_mapping = [
            # (fused parameter, checkpoint name fragment, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        skipped_fragments = (
            "rotary_emb.inv_freq",
            "head.attention",
            "head.layernorm",
            "head.mlp",
            "head.probe",
        )
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if any(fragment in name for fragment in skipped_fragments):
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor directly.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
@MULTIMODAL_REGISTRY.register_processor(
    PaddleOCRVLMultiModalProcessor,
    info=PaddleOCRVLProcessingInfo,
    dummy_inputs=PaddleOCRVLDummyInputsBuilder,
)
class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsMRoPE):
    """PaddleOCR-VL: vision encoder + projector + ERNIE-4.5 causal LM.

    Images are encoded by ``self.visual`` (``SiglipVisionModel``), projected
    by ``self.mlp_AR`` (``Projector``) and merged into the token embeddings
    that are fed to ``self.language_model`` (``Ernie4_5ForCausalLM``).
    """

    # NOTE(review): consumed by vLLM's multimodal input plumbing; semantics
    # are defined outside this file — confirm before changing.
    merge_by_field_config = True

    # Checkpoint keys "model.*" / "lm_head.*" belong to the language model,
    # which is mounted under the "language_model." sub-module here.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Global vLLM config; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.visual = SiglipVisionModel(
            config=config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            attn_backend_override=attn_backend_override,
        )
        self.mlp_AR = Projector(config, config.vision_config)

        self.language_model = Ernie4_5ForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Force NeoX-style rotary embedding on every decoder layer hosted by
        # this pipeline rank (PPMissingLayer placeholders have no attention).
        # NOTE(review): presumably required to match the checkpoint's RoPE
        # layout — confirm.
        for layer in self.language_model.model.layers:
            if not isinstance(layer, PPMissingLayer):
                layer.self_attn.rotary_emb.is_neox_style = True

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor,
        video_grid_thw: list[list[int]] | torch.Tensor,
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value.

        Text tokens get identical positions on all 3 rows. Each image/video
        span of ``t * (h // merge) * (w // merge)`` tokens gets
        (temporal, height, width) grid positions offset past the preceding
        text.

        Args:
            input_tokens: Prompt token ids.
            hf_config: HF config providing the special token ids and
                ``vision_config.spatial_merge_size``.
            image_grid_thw: Per-image ``(t, h, w)`` grids, in prompt order.
            video_grid_thw: Per-video ``(t, h, w)`` grids, in prompt order.
            second_per_grid_ts: Optional per-video temporal scale factors.
            context_len: Start column of the returned position slice.
            seq_len: End column (exclusive) of the returned position slice.
            audio_feature_lengths: Unused; kept for interface compatibility.
            use_audio_in_video: Unused; kept for interface compatibility.

        Returns:
            ``(positions, delta)``: a ``(3, seq_len - context_len)`` tensor of
            positions, and ``max_position + 1 - len(input_tokens)``.
        """
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # The token right after each vision-start marker tells image vs video.
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())
        llm_pos_ids_list: list[torch.Tensor] = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums
        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            # Locate the next image / video placeholder; len + 1 means "none".
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text preceding this vision span: same position on all 3 rows.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            # For images video_second_per_grid_t is 0.0, so t_index is all 0.
            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision span.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]
        return llm_positions, mrope_position_delta

    def get_language_model(self) -> nn.Module:
        """Return the wrapped ERNIE-4.5 language model."""
        return self.language_model

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> PaddleOCRImagePixelInputs | None:
        """Wrap ``pixel_values`` / ``image_grid_thw`` kwargs, or return None.

        Returns None when no ``pixel_values`` were passed. No further
        validation is performed here.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)
        if pixel_values is None:
            return None
        return PaddleOCRImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the language model, building multimodal embeddings if needed.

        When neither ``intermediate_tensors`` nor ``inputs_embeds`` is given,
        image embeddings are computed from ``kwargs`` and merged into the
        token embeddings, and ``input_ids`` is then dropped in favour of
        ``inputs_embeds``.
        """
        if intermediate_tensors is not None:
            # Non-first pipeline stage: hidden states come from the previous
            # stage, so no embeddings are needed here.
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            is_multimodal = kwargs.pop("is_multimodal", None)
            handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                vision_embeddings,
                is_multimodal=is_multimodal,
                handle_oov_mm_token=handle_oov_mm_token,
            )
            input_ids = None
        return self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image; other modalities raise."""
        if modality.startswith("image"):
            return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
        raise ValueError("仅支持图像模态")

    def encode_image(
        self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor
    ) -> torch.Tensor:
        """Run the vision encoder on the patches of a single image.

        Args:
            pixel_values: Patch tensor for one image; cast to the encoder's
                dtype before use.
            image_grid_thw: This image's ``(t, h, w)`` patch grid. Assumes a
                1-D tensor of ints — TODO confirm against the processor.

        Returns:
            The vision encoder output.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        thw_tuple = tuple(image_grid_thw.tolist())
        numel = np.prod(thw_tuple)
        # Position ids restart every h * w patches, i.e. once per temporal slice.
        siglip_position_ids = (torch.arange(numel) % np.prod(thw_tuple[1:])).to(
            pixel_values.device
        )
        # A single sequence spanning all `numel` patches.
        cu_seqlens = torch.tensor([0, numel], dtype=torch.int32).to(
            pixel_values.device
        )
        return self.visual(
            pixel_values=pixel_values,
            image_grid_thw=[thw_tuple],
            position_ids=siglip_position_ids,
            interpolate_pos_encoding=True,
            cu_seqlens=cu_seqlens,
        )

    def _process_image_input(
        self, image_input: PaddleOCRImagePixelInputs
    ) -> MultiModalEmbeddings:
        """Encode every image separately, then project them all with ``mlp_AR``."""
        pixel_values = image_input.pixel_values
        image_grid_thw = image_input.image_grid_thw
        vision_outputs = tuple(
            self.encode_image(pixel, grid).squeeze(0)
            for pixel, grid in zip(pixel_values, image_grid_thw)
        )
        return self.mlp_AR(vision_outputs, image_grid_thw)

    def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
        """Return a tuple of per-image embeddings (empty if no image input)."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return ()
        image_embeds = self._process_image_input(image_input)
        return tuple(image_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all checkpoint weights, remapping LM keys via ``hf_to_vllm_mapper``.

        Returns:
            The set of parameter names reported as loaded by ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)