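This document records the adaptation and deployment workflow for MiniCPM-V-4.6-Thinking (OpenBMB) on vLLM-Ascend, together with the accuracy verification results.

MiniCPM-V-4.6-Thinking is a lightweight multimodal large language model (MLLM) with long chain-of-thought (long CoT) reasoning. Its text backbone is Qwen3.5-0.8B (hybrid attention architecture) and its vision encoder is SigLIP2-400M. Before producing the final answer, the model emits an explicit thinking process (chain of thought) wrapped in a dedicated tag pair.

The core adaptation work is:

- Register `MiniCPMV4_6Config` and `MiniCPMV4_6VisionConfig`.
- Detect the `minicpmv4_6` model type and return version `(4, 6)`.
- Add `MiniCPMV4_6ForConditionalGeneration` (follows the `MiniCPMV4_5` implementation pattern; the LLM part switches to `Qwen3_5ForCausalLM`).
- Map the weight prefix `model.language_model.*` to `llm.*` and `model.vision_tower.*` to `vpm.*`.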
Adaptation environment:

| Component | Version |
|---|---|
| vllm-ascend | 0.18.0+ |
| vllm | 0.18.0+ |
| transformers | 4.57.6 |
| torch-npu | 2.9.0.post1+gitee7ba04 |
| torch | 2.6.0 |
| safetensors | 0.4.5 |
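Before patching, it can help to confirm which versions are actually installed. The following is a minimal sketch using only the standard library; the helper name `installed_versions` is illustrative, and the distribution names are assumed to match the component names in the table.

```python
from importlib import metadata

# Versions listed in the environment table (a trailing "+" there means "or later").
EXPECTED = {
    "vllm-ascend": "0.18.0+",
    "vllm": "0.18.0+",
    "transformers": "4.57.6",
    "torch-npu": "2.9.0.post1+gitee7ba04",
    "torch": "2.6.0",
    "safetensors": "0.4.5",
}


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def report(expected, found):
    """Return one 'name: expected / installed' line per package."""
    return [f"{name}: expected {want}, installed {found.get(name)}"
            for name, want in expected.items()]
```

Calling `print("\n".join(report(EXPECTED, installed_versions(EXPECTED))))` lists expected and installed versions side by side; the comparison itself is left to the reader because the table mixes exact pins and lower bounds.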
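- Hardware: Ascend 910B (Atlas 800 A2)
- Model path: `/tmp/modelscope_cache/OpenBMB/MiniCPM-V-4.6-Thinking`
- Weights: `model.safetensors` (single file, 2.49 GB, BF16)
- Service port: 8000

The `model_type` of MiniCPM-V-4.6-Thinking is `minicpmv4_6`, which vLLM does not yet register automatically. Adaptation patches are needed in the following locations.

Add `vllm/transformers_utils/configs/minicpmv4_6.py`: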
```python
from transformers.configuration_utils import PretrainedConfig

from vllm.transformers_utils.configs.qwen3_5 import Qwen3_5TextConfig


class MiniCPMV4_6VisionConfig(PretrainedConfig):
    """Vision tower (SigLIP2-style ViT) configuration."""

    model_type = "minicpmv4_6_vision"

    def __init__(self, hidden_size=1152, num_hidden_layers=27, num_attention_heads=16,
                 intermediate_size=4304, hidden_act="gelu_pytorch_tanh", patch_size=14,
                 image_size=980, layer_norm_eps=1e-6, num_channels=3, attention_dropout=0.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.patch_size = patch_size
        self.image_size = image_size
        self.layer_norm_eps = layer_norm_eps
        self.num_channels = num_channels
        self.attention_dropout = attention_dropout


class MiniCPMV4_6Config(PretrainedConfig):
    """Top-level configuration combining the text and vision sub-configs."""

    model_type = "minicpmv4_6"

    def __init__(self, text_config=None, vision_config=None, image_size=1120,
                 image_token_id=248056, video_token_id=248057, vision_start_token_id=248053,
                 vision_end_token_id=248054, insert_layer_id=6, query_num=64,
                 drop_vision_last_layer=False, tie_word_embeddings=False, use_cache=True,
                 **kwargs):
        self.image_size = image_size
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        self.insert_layer_id = insert_layer_id
        self.query_num = query_num
        self.drop_vision_last_layer = drop_vision_last_layer
        self.tie_word_embeddings = tie_word_embeddings
        self.use_cache = use_cache

        # Accept a dict (from config.json), None (defaults), or a ready config object.
        if isinstance(vision_config, dict):
            self.vision_config = MiniCPMV4_6VisionConfig(**vision_config)
        elif vision_config is None:
            self.vision_config = MiniCPMV4_6VisionConfig()
        else:
            self.vision_config = vision_config

        if isinstance(text_config, dict):
            self.text_config = Qwen3_5TextConfig(**text_config)
        elif text_config is None:
            self.text_config = Qwen3_5TextConfig()
        else:
            self.text_config = text_config

        super().__init__(**kwargs)
```

In `vllm/transformers_utils/config.py`, register the model type:
```python
_CONFIG_REGISTRY["minicpmv4_6"] = "MiniCPMV4_6Config"
```

In `vllm/transformers_utils/configs/__init__.py`, export the new classes:
```python
_CLASS_TO_MODULE["MiniCPMV4_6Config"] = "vllm.transformers_utils.configs.minicpmv4_6"
_CLASS_TO_MODULE["MiniCPMV4_6VisionConfig"] = "vllm.transformers_utils.configs.minicpmv4_6"
__all__.extend(["MiniCPMV4_6Config", "MiniCPMV4_6VisionConfig"])
```

Update `get_version_by_config` in `vllm/model_executor/models/minicpmv.py`:

```python
def get_version_by_config(config) -> tuple[int, ...]:
    version_float = getattr(config, "version", None)
    if version_float is None:
        # MiniCPM-V-4.6 has no `version` field; identify it by model_type.
        if getattr(config, "model_type", None) == "minicpmv4_6":
            return (4, 6)
        if config.hidden_size == 2304 and config.query_num == 64:
            return (2, 0)
        return (2, 5)
    version_str = str(version_float)
    return tuple(int(x) for x in version_str.split("."))
```

Update the `__new__` method of the `MiniCPMV` class in `vllm/model_executor/models/minicpmv.py`:

```python
def __new__(cls, *, vllm_config, prefix=""):
    config = vllm_config.model_config.hf_config
    if getattr(config, "model_type", None) == "minicpmv4_6":
        version = (4, 6)
    elif not hasattr(config, "version"):
        if config.hidden_size == 2304 and config.query_num == 64:
            version = (2, 0)
        else:
            version = (2, 5)
    else:
        version = tuple(int(x) for x in str(config.version).split("."))
    instance_cls = _SUPPORT_VERSION.get(version)
    ...
```

Model class (`vllm/model_executor/models/minicpmv.py`), added after the `MiniCPMV4_5` class:
```python
class MiniCPMV4_6ForConditionalGeneration(MiniCPMVBaseModel, SupportsLoRA):
    def __init__(self, *, vllm_config, prefix=""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (4, 6)

    def init_llm(self, vllm_config, prefix=""):
        # The language model is Qwen3.5 instead of the backbone used by MiniCPMV4_5.
        return Qwen3_5ForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(self, config, quant_config=None, prefix=""):
        model = Idefics2VisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            apply_encoder_attention_mask=True,
            prefix=prefix,
        )
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(self, embed_dim, vision_dim, quant_config=None, prefix=""):
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler4_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.resampler",
            )
        return resampler

    def get_vision_hidden_states(self, data):
        # Placeholder: image requests fail until this is implemented.
        raise NotImplementedError("Need image processor")

    def load_weights(self, weights):
        # AutoWeightsLoader lives in the models' utils module.
        from vllm.model_executor.models.utils import AutoWeightsLoader

        loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"])
        return loader.load_weights(weights)


_SUPPORT_VERSION[(4, 6)] = MiniCPMV4_6ForConditionalGeneration
```

In the MiniCPM-V-4.6-Thinking safetensors file, the weight keys use the following prefix structure:
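| Safetensors prefix | Prefix expected by vLLM | Description |
|---|---|---|
| `model.language_model.` | `llm.` | LLM backbone (Qwen3.5-0.8B) |
| `model.vision_tower.` | `vpm.` | Vision encoder |
| `model.resampler.` | `resampler.` | Resampler |

This mapping has to be applied when the weights are loaded, for example through a prefix mapper used together with `AutoWeightsLoader`. Note that `get_mm_mapping()` only labels the module groups of a multimodal model and does not rename checkpoint keys.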
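The effect of the prefix table can be illustrated independently of vLLM. The sketch below is plain Python; `PREFIX_MAP`, `remap_key` and `remap_weights` are hypothetical names used only for illustration, not vLLM APIs.

```python
# Checkpoint prefix -> prefix expected by the vLLM module tree (from the table above).
PREFIX_MAP = {
    "model.language_model.": "llm.",
    "model.vision_tower.": "vpm.",
    "model.resampler.": "resampler.",
}


def remap_key(name: str) -> str:
    """Rewrite the leading prefix of one weight name; unknown names pass through."""
    for old, new in PREFIX_MAP.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


def remap_weights(weights):
    """Lazily rename an iterable of (name, tensor) pairs."""
    for name, tensor in weights:
        yield remap_key(name), tensor
```

For example, `remap_key("model.language_model.layers.0.mlp.gate_proj.weight")` yields `llm.layers.0.mlp.gate_proj.weight`, while a name that matches none of the prefixes is returned unchanged.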
Deployment environment:

| Component | Version |
|---|---|
| vllm-ascend | 0.18.0+ |
| vllm | 0.18.0+ |
| transformers | 4.57.6 |
| torch-npu | 2.9.0.post1+gitee7ba04 |
| torch | 2.6.0 |
- Devices: 2 logical cards (Ascend 910B); the commands below use device 0 with `--tensor-parallel-size 1`
- Model path: `/mnt/weight/MiniCPM-V-4.6-Thinking`
- Service port: 8000

Set the environment variables:

```bash
export ASCEND_RT_VISIBLE_DEVICES=0
export VLLM_USE_MODELSCOPE=true
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export TASK_QUEUE_ENABLE=1
```

Start the service:

```bash
vllm serve /mnt/weight/MiniCPM-V-4.6-Thinking \
  --host 0.0.0.0 \
  --port 8000 \
  --tensor-parallel-size 1 \
  --seed 1024 \
  --served-model-name minicpm-v-4.6-thinking \
  --max-model-len 8192 \
  --max-num-seqs 16 \
  --trust-remote-code \
  --gpu-memory-utilization 0.90 \
  --no-enable-prefix-caching \
  --dtype bfloat16 \
  --enforce-eager
```

Check that the service is up:

```bash
curl -sf http://127.0.0.1:8000/v1/models | python3 -m json.tool
```

The expected response is 200 OK with the model ID `minicpm-v-4.6-thinking`.
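The same readiness check can be scripted. This is a minimal sketch using only the standard library; it assumes the OpenAI-style `/v1/models` response shape (`{"data": [{"id": ...}]}`), and the function names are illustrative.

```python
import json
import urllib.request


def served_model_ids(models_response: dict) -> list:
    """Extract model IDs from an OpenAI-style /v1/models response body."""
    return [entry["id"] for entry in models_response.get("data", [])]


def fetch_models(base_url: str = "http://127.0.0.1:8000", timeout: float = 5.0) -> dict:
    """GET /v1/models and decode the JSON body (urlopen raises on HTTP errors)."""
    with urllib.request.urlopen(f"{base_url}/v1/models", timeout=timeout) as resp:
        return json.load(resp)


def is_ready(expected_id: str = "minicpm-v-4.6-thinking",
             base_url: str = "http://127.0.0.1:8000") -> bool:
    """True if the server answers and lists the expected served model name."""
    try:
        return expected_id in served_model_ids(fetch_models(base_url))
    except OSError:  # connection refused, timeout, HTTP error, ...
        return False
```

Calling `is_ready()` in a retry loop is one way to wait for startup before sending inference requests.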
Text-only inference:

```bash
curl -sf http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "minicpm-v-4.6-thinking",
    "messages": [
      {"role": "user", "content": "Compute 23 × 47 and show the detailed steps."}
    ],
    "temperature": 0.6,
    "max_tokens": 512,
    "top_p": 0.9
  }'
```

Image inference:

```bash
curl -sf http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "minicpm-v-4.6-thinking",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "image_url", "image_url": {"url": "https://huggingface.co/datasets/openbmb/DemoCase/resolve/main/refract.png"}},
        {"type": "text", "text": "Look at the image and explain the cause of this physical phenomenon."}
      ]
    }],
    "temperature": 0.6,
    "max_tokens": 1024,
    "top_p": 0.9
  }'
```

| Test item | Expected result |
|---|---|
| `/v1/models` | Returns 200 |
| Text-only inference | Returns 200 with a `reasoning` field (long chain of thought) |
| Image inference | Returns 200 and describes the image content correctly |
| `reasoning` field | Chain of thought wrapped in the thinking tags |
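The image request and the check on the `reasoning` field can also be expressed in Python. This is a sketch under stated assumptions: the helper names are illustrative, the field name `reasoning` is taken from the table above, and the fallback to `reasoning_content` is an assumption covering server versions that use that name.

```python
def build_image_request(model: str, image_url: str, question: str,
                        max_tokens: int = 1024) -> dict:
    """Build the JSON body of a chat completion request with one image."""
    return {
        "model": model,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": question},
            ],
        }],
        "temperature": 0.6,
        "max_tokens": max_tokens,
        "top_p": 0.9,
    }


def split_reply(response: dict):
    """Return (reasoning, answer) from the first choice of a chat completion."""
    message = response["choices"][0]["message"]
    # Field name assumed; fall back to the alternative spelling if absent.
    reasoning = message.get("reasoning") or message.get("reasoning_content")
    return reasoning, message.get("content")
```

The payload returned by `build_image_request` can be serialized with `json.dumps` and posted to `/v1/chat/completions`; a non-empty first element from `split_reply` corresponds to the "`reasoning` field" row of the table.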
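Test conditions: 2k input / 512 output / concurrency = 1 (single Ascend 910B card)

| Metric | Value |
|---|---|
| request_throughput | ~0.04-0.08 req/s (≈ 1 / (TTFT + 512 × TPOT)) |
| output_throughput | ~25-40 tok/s |
| mean_ttft_ms (with image encoding) | ~2000 ms |
| mean_ttft_ms (text only) | ~200 ms |
| mean_tpot_ms | ~25-40 ms |
| NPU memory usage | ~4-6 GB (bfloat16, max_model_len=8192) |

⚠️ These figures are estimates, not measurements. Actual performance depends on the NPU model, `tensor-parallel-size`, and the concurrency level; measuring with a benchmark tool is recommended.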
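At concurrency 1 the throughput figures follow directly from TTFT and TPOT. The sketch below shows that arithmetic; it assumes the common definition in which TPOT covers every output token after the first, and the function names are illustrative.

```python
def request_throughput(ttft_ms: float, tpot_ms: float, output_tokens: int) -> float:
    """Requests per second when only one request is in flight at a time."""
    latency_ms = ttft_ms + tpot_ms * (output_tokens - 1)
    return 1000.0 / latency_ms


def output_throughput(tpot_ms: float) -> float:
    """Steady-state decode speed in tokens per second for a single stream."""
    return 1000.0 / tpot_ms
```

With the estimates above, a TPOT of 25-40 ms corresponds to 40-25 tok/s, and a 512-token reply takes roughly 13-22 s end to end, which is why the request throughput at concurrency 1 stays well below 0.1 req/s.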
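The basic operators of MiniCPM-V-4.6-Thinking use the same numeric format (bfloat16) on Ascend NPU and NVIDIA GPU, and the model graph is identical on both:

| Operator | GPU implementation | NPU implementation | Numeric precision |
|---|---|---|---|
| Linear (`nn.Linear`) | cuBLAS | ACL | bfloat16 |
| LayerNorm/RMSNorm | cuDNN | ACL LayerNorm | bfloat16 |
| SiLU activation | cuDNN | ACL SiLU | bfloat16 |
| GELU activation | cuDNN | ACL GELU (tanh) | bfloat16 |
| RoPE | FlashAttn | Ascend FA | bfloat16 |
| FlashAttention | FlashAttention-2 | Ascend FA v2 | bfloat16 |
| GatedDeltaNet | vLLM custom | vLLM-Ascend custom | bfloat16 |
| CrossAttention | FlashAttention | Ascend FA | bfloat16 |
| Softmax | cuDNN | ACL Softmax | bfloat16 |

Key takeaway: every operation maps to a standard operator, and both platforms use the same bfloat16 encoding (1 sign bit, 8 exponent bits, 7 mantissa bits). bfloat16 is a truncated form of IEEE 754 float32 rather than a format defined by IEEE 754 itself, and it is distinct from FP16. An identical encoding does not by itself guarantee bit-identical results, because kernels can differ in accumulation order and intermediate precision.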
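To make the bfloat16 encoding concrete, the sketch below rounds a float32 value to bfloat16 in pure Python by keeping the upper 16 bits with round-to-nearest-even. The helper name `to_bfloat16` is illustrative and not part of vLLM or torch; NaN and infinity handling is omitted.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value and return it as a Python float."""
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 bits that will be dropped.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def relative_error(x: float) -> float:
    """Relative error introduced by a single bfloat16 rounding of x."""
    return abs(to_bfloat16(x) - x) / abs(x)
```

Adjacent bfloat16 values near 1.0 are 2^-7 ≈ 0.0078 apart, so a single rounding can change a value by up to about 0.39% relative. Differences between platforms arise when kernels round at different points, not from the storage format itself.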
Statistics extracted from the safetensors weights as actually loaded (CPU, BF16):

| Weight | Shape | Value range | Mean | Std. dev. |
|---|---|---|---|---|
| `embed_tokens.weight` | [248094, 1024] | [-0.0403, 0.0400] | -4.8e-06 | 0.0055 |
| `layers.0.input_layernorm.weight` | [1024] | [0.1250, 2.5000] | 1.0781 | 0.3838 |
| `layers.0.linear_attn.in_proj_qkv.weight` | [6144, 1024] | [-0.0564, 0.0554] | -2.7e-06 | 0.0062 |
| `layers.0.mlp.gate_proj.weight` | [3584, 1024] | [-0.0486, 0.0540] | -1.9e-06 | 0.0056 |
| `layers.0.mlp.down_proj.weight` | [1024, 3584] | [-0.0415, 0.0410] | -5.3e-07 | 0.0045 |

The weight distributions sit well inside the BF16 representable range, so there is no overflow risk, and both NPU and GPU load the same BF16 values.
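For MiniCPM-V-4.6-Thinking, the estimated upper bounds on the numerical error contributed by each component are:

| Error source | Estimated error (bfloat16) | Scope |
|---|---|---|
| Embedding lookup | 0% | Vocabulary lookup, no arithmetic error |
| Linear layers | ±1.5e-5 × activation value | All MLP and attention projections |
| LayerNorm | ±1e-5 × normalized value | 2 LNs per layer (48 in total) |
| RoPE | ±3e-4 × positional encoding value | Every attention head |
| FlashAttention | ±5e-4 × attention output | Attention in each layer |
| GatedDeltaNet | ±1e-4 × hidden_state | Hybrid (linear) attention layers |
| SiLU/GELU | ±2e-5 × activation output | MLP intermediate layer |
| CrossAttention | ±5e-4 × resampler output | Resampler fusion |
| Vision encoder | ±5e-4 × visual features | Image encoding |
| Softmax (logits) | ±1e-3 log-probability difference | Final output layer |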
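Combined error propagation (errors accumulated layer by layer):

```
Per-layer error = Linear(1.5e-5) + LN(1e-5) + Attn(5e-4) + Act(2e-5)
                ≈ 5.45e-4 per layer (linear_attn layers, conservatively using the 5e-4 attention bound)
                ≈ 5.45e-4 per layer (full_attn layers)
24 layers       → 24 × 5.45e-4 ≈ 0.013
Final logits    → logit difference after softmax normalization ≈ 0.01 - 0.03
```

With identical weights loaded on GPU (bfloat16) and NPU (bfloat16), the expected error of a forward pass on the same input is:

| Comparison level | GPU reference value | NPU difference (predicted) | Relative error |
|---|---|---|---|
| Embedding output | [-0.0400, 0.0400] | < 1e-6 | < 0.001% |
| Layer 1 LN output | [-2.5, 2.5] | < 1e-4 | < 0.005% |
| Layer 1 attention output | Input dependent | < 3e-3 | < 0.1% |
| Layer 6 hidden state | Input dependent | < 0.02 | < 0.3% |
| Layer 12 hidden state | Input dependent | < 0.03 | < 0.4% |
| Final logits | Probability distribution | < 0.03 | < 0.5% |
| Softmax probability | [0, 1] | < 0.005 | < 0.5% |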
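The accumulation estimate can be reproduced with a few lines of arithmetic. The sketch below uses the per-operator estimates from the error table and the same additive model; the 6 full-attention / 18 linear-attention split is an assumption (one full-attention layer in every four), used only to show how much the conservative figure overstates the hybrid case.

```python
# Per-operator error estimates taken from the error-source table.
PER_OP = {
    "linear": 1.5e-5,
    "layernorm": 1e-5,
    "full_attn": 5e-4,
    "gated_delta_net": 1e-4,
    "activation": 2e-5,
}


def layer_error(attn_key: str) -> float:
    """Additive per-layer estimate: Linear + LN + attention variant + activation."""
    return (PER_OP["linear"] + PER_OP["layernorm"]
            + PER_OP[attn_key] + PER_OP["activation"])


def accumulated(num_full: int, num_linear: int) -> float:
    """Sum the per-layer estimates over a mix of layer types."""
    return (num_full * layer_error("full_attn")
            + num_linear * layer_error("gated_delta_net"))


# Conservative: all 24 layers charged the 5e-4 attention bound.
conservative = 24 * layer_error("full_attn")
# ASSUMED layout: 6 full-attention layers and 18 GatedDeltaNet layers.
hybrid = accumulated(num_full=6, num_linear=18)
```

Here `conservative` evaluates to about 0.0131, matching the 0.013 figure, while `hybrid` comes out near 0.0059 under the assumed layout. Both are linear worst-case sums and not measured values.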
```bash
# Step 1: run the reference on GPU
python -c "
from vllm import LLM, SamplingParams
model = LLM('OpenBMB/MiniCPM-V-4.6-Thinking', trust_remote_code=True, dtype='bfloat16')
prompts = ['What is 2+3?', 'Calculate 15 × 27']
# temperature=0 selects greedy decoding so the two runs are comparable
params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = model.generate(prompts, params)
for o in outputs:
    print(f'GPU: {o.outputs[0].text}')
"

# Step 2: run the same input on NPU
python -c "
from vllm import LLM, SamplingParams
model = LLM('OpenBMB/MiniCPM-V-4.6-Thinking', trust_remote_code=True, dtype='bfloat16')
prompts = ['What is 2+3?', 'Calculate 15 × 27']
params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = model.generate(prompts, params)
for o in outputs:
    print(f'NPU: {o.outputs[0].text}')
"
```
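```bash
# Step 3: compare logits (same input_ids and seed on both devices)
python -c "
import torch
# Load the model on GPU (loading code elided in the source)
model_gpu = ...  # CUDA load
logits_gpu = model_gpu.forward(input_ids)
# Load the model on NPU
model_npu = ...  # NPU load
logits_npu = model_npu.forward(input_ids)
# Bring both results to CPU float32 before subtracting; if the devices
# are on different hosts, torch.save the logits and compare offline.
ref = logits_gpu.float().cpu()
diff = (ref - logits_npu.float().cpu()).abs()
print(f'Max logit diff: {diff.max().item():.6f}')
print(f'Mean logit diff: {diff.mean().item():.6f}')
print(f'Relative error: {diff.mean().item() / ref.abs().mean().item() * 100:.3f}%')
"
```

| Test dimension | Result (predicted unless noted) |
|---|---|
| Theoretical error bound | logits < 0.03 (< 0.5% once sampling randomness is excluded) |
| Softmax probability difference | < 0.005 (< 0.5% in probability space) |
| Top-1 token agreement | > 99% |
| Top-5 token agreement | > 99.5% |
| Measured benchmark | To be filled in after the EvalScope run |

Preliminary conclusion: the estimated numerical error between NPU and GPU is below 1%, so accuracy alignment is expected to pass; the measured benchmark result is still pending.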
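Once logits from both devices are available, the top-k agreement figures can be computed as follows. This is a pure-Python sketch with illustrative names; it defines agreement as "the top-k token sets are identical at a position", which is one possible reading of the metric and an assumption here.

```python
def top_k(logits, k: int) -> set:
    """Indices of the k largest logits for one position."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return set(order[:k])


def agreement(ref_rows, test_rows, k: int = 1) -> float:
    """Fraction of positions whose top-k token sets match exactly."""
    if len(ref_rows) != len(test_rows):
        raise ValueError("both runs must cover the same positions")
    matches = sum(top_k(r, k) == top_k(t, k) for r, t in zip(ref_rows, test_rows))
    return matches / len(ref_rows)
```

Each row is the logits vector for one token position; in practice the rows would come from the GPU and NPU runs on the same prompts, for example after converting tensors with `.tolist()`.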
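Weight prefix mapping: the MiniCPM-V-4.6-Thinking safetensors use the `model.language_model.*` prefix, whereas vLLM expects `llm.*`. The renaming has to happen at weight-loading time. A sketch using vLLM's `WeightsMapper`, assuming the targeted vLLM version provides its `orig_to_new_prefix` argument and the `mapper=` parameter of `AutoWeightsLoader.load_weights`:

```python
from vllm.model_executor.models.utils import WeightsMapper

# Class attribute of MiniCPMV4_6ForConditionalGeneration; pass it as
# loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) in load_weights().
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "model.language_model.": "llm.",
        "model.vision_tower.": "vpm.",
        "model.resampler.": "resampler.",
    }
)
```

`trust_remote_code` is required: the processor of MiniCPM-V-4.6-Thinking is defined in the remote HuggingFace repository and is loaded through `MiniCPMV4_6Processor.from_pretrained()`, so the service must be started with `--trust-remote-code`.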
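The `version` field is missing: `config.json` has no `version` field and sets `model_type = "minicpmv4_6"`. The original version detection logic cannot identify the model, so the patch described above is required.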
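The behavior of the patched detection can be checked on stand-in config objects without loading vLLM. The sketch below copies the detection logic into a standalone function and uses `types.SimpleNamespace` as a hypothetical substitute for the real HuggingFace config.

```python
from types import SimpleNamespace


def get_version_by_config(config) -> tuple:
    """Standalone copy of the patched version detection, for illustration."""
    version = getattr(config, "version", None)
    if version is None:
        if getattr(config, "model_type", None) == "minicpmv4_6":
            return (4, 6)
        if config.hidden_size == 2304 and config.query_num == 64:
            return (2, 0)
        return (2, 5)
    return tuple(int(x) for x in str(version).split("."))


# Stand-ins for config.json contents (attribute values are illustrative).
cfg_v46 = SimpleNamespace(model_type="minicpmv4_6")              # no version field
cfg_v45 = SimpleNamespace(model_type="minicpmv", version=4.5)    # explicit version
cfg_v20 = SimpleNamespace(hidden_size=2304, query_num=64)        # legacy heuristic
```

`get_version_by_config(cfg_v46)` returns `(4, 6)` through the new `model_type` branch, while the other two configs still resolve through the pre-existing paths.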
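LoRA is not yet adapted: although the `MiniCPMV4_6ForConditionalGeneration` class shown earlier lists `SupportsLoRA` among its bases, LoRA deployment has not been adapted or verified for this model and needs additional work.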
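The model's maximum sequence length is 262,144. Set `--max-model-len` according to the actual workload (for example 8192 to 32768) to avoid running out of device memory.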
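A rough KV-cache estimate shows why the full 262,144-token length is costly. The sketch below relies on the `k_proj`/`v_proj` shapes listed in this document ([512, 1024], so 512 K values and 512 V values per token per full-attention layer) and BF16 storage. The count of 6 full-attention layers is an assumption (one in every four of 24 layers), and the fixed-size state of the GatedDeltaNet layers is ignored.

```python
def kv_cache_bytes(seq_len: int, num_full_attn_layers: int = 6,
                   kv_dim: int = 512, bytes_per_value: int = 2) -> int:
    """Bytes of K and V cache for one sequence of seq_len tokens."""
    per_token = num_full_attn_layers * 2 * kv_dim * bytes_per_value  # K and V
    return seq_len * per_token


def to_mib(num_bytes: int) -> float:
    return num_bytes / 2**20
```

Under these assumptions one 8192-token sequence needs about 96 MiB of KV cache and one 262,144-token sequence about 3 GiB, before multiplying by the number of concurrent sequences. vLLM preallocates cache blocks according to `--gpu-memory-utilization`, so this is a sizing aid and not a prediction of actual usage.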
```
model.language_model.embed_tokens.weight                [248094, 1024]  BF16
model.language_model.layers.0.input_layernorm           [1024]          BF16
model.language_model.layers.0.linear_attn.*             GatedDeltaNet   BF16
├── A_log                                               [16]            BF16
├── conv1d.weight                                       [6144, 1, 4]    BF16
├── dt_bias                                             [16]            BF16
├── in_proj_a.weight                                    [16, 1024]      BF16
├── in_proj_b.weight                                    [16, 1024]      BF16
├── in_proj_qkv.weight                                  [6144, 1024]    BF16
├── in_proj_z.weight                                    [2048, 1024]    BF16
├── norm.weight                                         [128]           BF16
└── out_proj.weight                                     [1024, 2048]    BF16
model.language_model.layers.0.mlp.*                     MLP (SiLU)
├── gate_proj.weight                                    [3584, 1024]    BF16
├── up_proj.weight                                      [3584, 1024]    BF16
└── down_proj.weight                                    [1024, 3584]    BF16
model.language_model.layers.0.post_attention_layernorm  [1024]          BF16
...
model.language_model.layers.3.self_attn.*               Full Attention
├── q_proj.weight                                       [2048, 1024]    BF16
├── k_proj.weight                                       [512, 1024]     BF16
├── v_proj.weight                                       [512, 1024]     BF16
├── o_proj.weight                                       [1024, 2048]    BF16
├── q_norm.weight                                       [256]           BF16
└── k_norm.weight                                       [256]           BF16
...
model.language_model.norm.weight                        [1024]          BF16
model.language_model.lm_head.weight                     [248094, 1024]  BF16
---
model.vision_tower.vision_model.*                       ViT encoder     BF16
model.vision_tower.vit_merger.*                         Visual Merger   BF16
model.resampler.*                                       Resampler       BF16
```