Ascend-SACT/GPT-OSS-20B-BF16-based-vllm-ascend
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

摘要:

【背景】

在昇腾上基于vllm部署gpt-oss模型时,发现vllm-ascend主线版本存在断点,未对gpt-oss模型进行适配。gpt-oss模型相比主流开源模型存在以下差异:

  1. Attention部分隔层交叉使用GQA和SWA、引入sink bias;
  2. MoE部分魔改了SwiGLU;
  3. 网络线性层使用了bias结构。

这些差异导致vllm-ascend主线不足以支撑gpt-oss的推理部署,vllm-ascend社区提交的小算子模式PR,推理性能较差(10并发、输入128、输出128下,平均时延17.61s,平均TTFT 319ms,平均TPOT 136ms),无法满足业务需求。

【解决方式】

  1. 融合算子替换,适配sink bias特性:升级CANN版本至8.3.RC1,基于torch_npu.npu_fused_infer_attention_score_v2接口,修改vllm-ascend社区attention计算部分代码,添加gpt-oss所需的SWA和sink bias特性适配;
  2. MOE断点补齐,适配swigluoai和bias特性:修改 vllm-ascend社区fused_moe部分代码,添加对swigluoai和bias结构的支持;
  3. 图模式断点补齐,完成PIECEWISE ACL图模式:修改vllm-ascend图模式相关代码,实现一次捕获,多次重放,缓解host下发瓶颈及NPU运行效率,降低NPU free占比,提升推理性能。

【效果】

通过FIA融合算子使能,叠加图模式优化。gpt-oss-20b平均推理时延降低84.57%,TPS提升6.48倍,TTFT降低31.97%,TPOT降低85.29%。精度测试采用默认的medium思考深度,采样参数"temperature":0,在MMLU所有子测试集的平均mean_acc为83.41%,相比论文公布的84%,精度误差在可接受的范围内。gpt-oss-120b平均推理时延降低94.99%,TPS提升67倍,TTFT降低36.58%,TPOT降低99.04%。在MMLU所有子测试集的平均mean_acc为87.55%,相比论文公布的88%,精度误差在可接受的范围内。

关键词: 【模型-GPT-OSS】【场景-推理】【问题-模型迁移】【问题-性能调优】【框架-vllm-ascend】

1. 背景描述

1.1 硬件环境

型号卡数模型
910B31gpt-oss-20b

1.2 软件版本

软件名版本
CANN8.3.RC1
Python3.11.13
transformers4.57.1
torch2.7.1
torch-npu2.7.1
vllmv0.11.2
链接Commit ID
vllm源码https://github.com/vllm-project/vllm.git275de34170654274616082721348b7edd9741d32
vllm-ascend源码https://github.com/taoyao1221/vllm-ascend/tree/add-gpt-ossc9b64052ee740fda213934b2b39151d00b3fe403

1.3 模型下载地址及模型配置

1.3.1 BF16模型权重地址

模型名称链接
gpt-oss-20b-BF16https://www.modelscope.cn/models/unsloth/gpt-oss-20b-BF16

1.3.2 GPT-OSS模型配置

模型/配置gpt-oss-120Bgpt-oss-20B
总参数量117B21B
激活参数量5.1B3.6B
MLP参数量114.71B19.12B
Attention参数量0.96B0.64B
Embedding&Unembedding参数量1.16B1.16B
位置编码RoPE+YaRNRoPE+YaRN
上下文长度128k128k
num_layers3624
hidden_size28802880
attention_typeGQA (SWA)GQA (SWA)
sliding_window_size128128
num_attention_heads6432
num_kv_heads88
head_dim6464
num_experts12832
num_activated_experts44
mlp_typeSwiGLUSwiGLU

2. 解决方案和结果简述

2.1 问题根因

  1. CANN8.3.RC1之前的版本aclnnFusedInferAttentionScore融合算子不支持learnable_sink特性;
  2. vllm-ascend社区attention计算未进行SWA和sink bias适配;
  3. vllm-ascend社区fused_moe未对swigluoai和bias结构做适配;
  4. 外部生态提交的针对gpt-oss适配的小算子模式PR,频繁的算子下发,带来显著的host瓶颈,NPU空泡严重,推理性能极差。

小算子模式NPU空泡严重:

小算子模式推理性能差:

2.2 解决措施

  1. 升级CANN版本至8.3.RC1,基于torch_npu.npu_fused_infer_attention_score_v2接口,修改vllm-ascend社区attention计算部分代码,添加gpt-oss所需的SWA和sink bias特性适配;
  2. 修改 vllm-ascend社区fused_moe部分代码,添加对swigluoai和bias结构的支持;
  3. 面向gpt-oss模型,完成PIECEWISE ACL图模式适配,实现一次捕获,多次重放,缓解host下发瓶颈,降低NPU free占比,提升推理性能。

2.3 结果

通过FIA融合算子使能,叠加图模式优化。 gpt-oss-20b平均推理时延降低84.57%,TPS提升6.48倍,TTFT降低31.97%,TPOT降低85.29%。精度测试采用默认的medium思考深度,采样参数"temperature":0,在MMLU所有子测试集的平均mean_acc为83.41%,相比论文公布的84%,在合理的误差范围内。 gpt-oss-120b**平均推理时延降低94.99%,TPS提升67倍,TTFT降低36.58%,TPOT降低99.04%。在MMLU所有子测试集的平均mean_acc为​87.55%​,相比论文公布的​88%**​,在合理的误差范围内。

3. 详细方案

3.1 环境准备

3.1.1 CANN镜像及开发容器

  1. 拉取CANN镜像:
docker pull m.daocloud.io/quay.io/ascend/cann:8.3.rc1-910b-ubuntu22.04-py3.11
  1. 启动开发容器:
export IMAGE=m.daocloud.io/quay.io/ascend/cann:8.3.rc1-910b-ubuntu22.04-py3.11
docker run -it -d --net=host \
    --name cann8.3.rc1-gpt-oss \
    --shm-size=1g \
    --privileged \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
    -v /etc/ascend_install.info:/etc/ascend_install.info \
    -v /home:/home \
    -it $IMAGE bash
  1. 进入开发容器:
docker exec -it cann8.3.rc1-gpt-oss bash

3.1.2 vllm及evalscop安装

  1. 安装和vllm-ascend配套的vllm版本:
cd vllm
git reset --hard 275de34170654274616082721348b7edd9741d32
pip install -r requirements/build.txt
VLLM_TARGET_DEVICE="empty" pip install .
  1. 安装vllm-ascend:
git clone -b add-gpt-oss https://github.com/taoyao1221/vllm-ascend.git
cd vllm-ascend
pip install -v -e .
  1. 安装eval_scope benchmark 精度及性能测试工具:
pip install evalscope
pip install evalscope[perf] -U

3.2 gpt-oss模型适配

3.2.1 gpt-oss模型attention适配

gpt-oss模型SWA Attention架构分析:

gpt-oss采用了隔层交叉使用SWA和GQA的注意力策略,具体来说就是模型中每一层的attention type在“滑动窗口注意力”和“全局注意力”之间切换。目的是在确保关键全局信息能被有效整合和传递的前提下提高计算效率,实现推理精度与性能的平衡。 gpt-oss模型Attention sinks bias分析:

gpt-oss为每个注意力头设置一个可学习Attention sinks bias,仅用来充当注意力分数计算时SoftMax的分母,以解决标准注意力机制中“强制输出”和“attention sinks”的问题。

vllm-ascend attention调用逻辑(待梳理):

  1. gpt-oss隔层交叉使用GQA和SWA特性适配:

    image 通过self.sliding_window是否为None判断当前层使用全局GQA还是SWA,通过注意力mask确定参与attention计算的token范围。

  2. gpt-oss模型sink特性适配:

    image

    image

    首先从vllm接收传入的sinks参数,并完成成员变量的初始化。

    image

    基于torch_npu.npu_fused_infer_attention_score_v2接口完成SWA和全局GQA的计算,sparse_mode=4表示使用SWA,sparse_mode=3表示使用标准的全局GQA。learnable_sink=self.sinks传参实现sink bias参与attention的计算。

3.2.2 gpt-oss模型MoE MLP适配

gpt-oss模型MoE架构分析:

  1. gpt-oss在专家MLP部分采用的是SwiGLU结构,但对结构进行了一系列细微调整,这部分实现和当前主流开源模型差异较多;
  2. MLP线性层重新引入bias。

修改moe_mlp.py,实现w1_bias、w2_bias、activation参数的传入,以及swigluoai和bias结构的适配。

  1. gpt-oss模型swigluoai适配: image

  2. gpt-oss模型bias结构适配:

    image

    image

3.2.3 Eager模式验证

  1. 离线推理验证脚本,添加enforce_eager=True:
import os
import argparse
from vllm import LLM, SamplingParams

os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0"
os.environ["VLLM_USE_V1"] = "1"
Profile:bool = False

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="/home/models/hf_models/gpt-oss-20b-BF16")
    parser.add_argument("--profile_path", type=str, default="/home/taoyao/profile_vllm_output")
    args = parser.parse_args()
    if Profile:
        os.environ["VLLM_TORCH_PROFILER_DIR"] = args.profile_path

    prompts = ["Who are you? Please introduce yourself and analyze your relationship with OpenAI."]
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=64)
    llm = LLM(model=args.model_path,
    tensor_parallel_size=1,
    max_model_len=16384,
    max_num_seqs=16,
    max_num_batched_tokens=4096,
    gpu_memory_utilization=0.95,
    # load_format="dummy",
    enforce_eager=True,
    )

    if Profile:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if Profile:
        llm.stop_profile()

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

if __name__=="__main__":
    main()
  1. 在线推理验证脚本,添加--enforce_eager:
#!/bin/sh
export VLLM_USE_V1=1
HOST=0.0.0.0
PORT=8085
export ASCEND_RT_VISIBLE_DEVICES=4
LOCAL_CKPT_DIR=/home/models/hf_models/gpt-oss-20b-BF16
SERVED_MODEL_NAME=gpt-oss-20b

TIKTOKEN_RS_CACHE_DIR=/home/models/hf_models/gpt-oss-20b-BF16/tiktiken_cache vllm serve $LOCAL_CKPT_DIR \
    --served-model-name $SERVED_MODEL_NAME \
    --gpu-memory-utilization 0.95 \
    --tensor-parallel-size 1 \
    --host $HOST \
    --port $PORT \
    --no-disable-log-requests \
    --max-model-len 16384 \
    --max-num-batched-tokens 4096 \
    --max-num-seqs 16 \
    --enforce_eager
  1. 启动在线推理服务后,执行以下命令进行API调用,测试服务是否正常:
curl --location --request POST 'http://127.0.0.1:8085/v1/chat/completions' \
--header 'Content-Type: application/json' \
--data-raw '{
"model": "gpt-oss-20b",
"messages": [
{"role": "user", "content": "Who are you? Please introduce yourself and analyze your relationship with OpenAI."}
]
}'

3.2.4 ACL图模式适配

  1. 修改moe_mlp部分代码,解决repeat_interleave操作触发H2D导致同步的问题,详见问题5.1.5;
  2. 将bias参数强转为fp32,解决torch_npu.npu_grouped_matmul算子不支持fp16的问题修改后代码:
    gate_up_out = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
        )[0]
  3. 修改attention部分代码,解决ACL图模式重放时因向上填充到对应档位时,可能导致的shape不匹配问题,详见问题5.1.6。

3.2.4 ACL图模式验证

离线推理验证脚本,添加compilation_config={"cudagraph_mode": "PIECEWISE", "cudagraph_capture_sizes": [16,4,2,1]},开启PIECEWISE ACL图模式,通过cudagraph_capture_sizes手动控制图捕获时的批次大小,建议使用业务常用的BS进行填充:

import os
import argparse
from vllm import LLM, SamplingParams

os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0"
os.environ["VLLM_USE_V1"] = "1"
Profile:bool = False

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="/home/models/hf_models/gpt-oss-20b-BF16")
    parser.add_argument("--profile_path", type=str, default="/home/taoyao/profile_vllm_output")
    args = parser.parse_args()
    if Profile:
        os.environ["VLLM_TORCH_PROFILER_DIR"] = args.profile_path

    prompts = ["Who are you? Please introduce yourself and analyze your relationship with OpenAI."]
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=64)
    llm = LLM(model=args.model_path,
    tensor_parallel_size=1,
    max_model_len=16384,
    max_num_seqs=16,
    max_num_batched_tokens=4096,
    gpu_memory_utilization=0.95,
    # load_format="dummy",
    # enforce_eager=True,
    compilation_config={"cudagraph_mode": "PIECEWISE", "cudagraph_capture_sizes": [16,4,2,1]},
    )

    if Profile:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if Profile:
        llm.stop_profile()

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

if __name__=="__main__":
    main()

在线推理验证脚本,添加--compilation_config '{"cudagraph_mode": "PIECEWISE", "cudagraph_capture_sizes": [16,10,4,2,1]}:

#!/bin/sh
export VLLM_USE_V1=1
HOST=0.0.0.0
PORT=8085
export ASCEND_RT_VISIBLE_DEVICES=4
LOCAL_CKPT_DIR=/home/models/hf_models/gpt-oss-20b-BF16
SERVED_MODEL_NAME=gpt-oss-20b

TIKTOKEN_RS_CACHE_DIR=/home/models/hf_models/gpt-oss-20b-BF16/tiktiken_cache vllm serve $LOCAL_CKPT_DIR \
    --served-model-name $SERVED_MODEL_NAME \
    --gpu-memory-utilization 0.95 \
    --tensor-parallel-size 1 \
    --host $HOST \
    --port $PORT \
    --no-disable-log-requests \
    --max-model-len 16384 \
    --max-num-batched-tokens 4096 \
    --max-num-seqs 16 \
    --compilation_config '{"cudagraph_mode": "PIECEWISE", "cudagraph_capture_sizes": [16,10,4,2,1]}'

3.3 性能测试

  1. 启动在线推理服务后,执行以下性能测试脚本:
evalscope perf \
  --parallel 1 10 \
  --number 10 20 \
  --model gpt-oss-20b \
  --url http://127.0.0.1:8085/v1/chat/completions \
  --api openai \
  --dataset random \
  --max-tokens 128 \
  --min-tokens 128 \
  --prefix-length 0 \
  --min-prompt-length 128 \
  --max-prompt-length 128 \
  --tokenizer-path /home/models/hf_models/gpt-oss-20b-BF16 \
  --extra-args '{"ignore_eos": true}'
  1. 小算子模式性能测试结果:
  2. FIA融合算子 Eager模式性能测试结果:
  3. PIECEWISE ACL图模式性能测试结果:
  4. GPT-OSS-20B 性能测试总结: 相比小算子模式,10并发、输入128、输出128下,通过FIA融合算子,叠加PIECEWISE ACL图模式优化后推理性能大大提升,平均推理时延降低84.57%,TPS提升6.48倍,TTFT降低31.97%,TTFT降低85.29%。

3.4 精度测试

gpt-oss论文呈现的精度测试结果,如下图:

  1. 启动在线推理服务后,执行以下MMLU测试集精度测试脚本,可通过--datasets mmlu指定测试集,通过配置 --work-dir和--use-cache进行断点续测:
export MODELSCOPE_CACHE=/home/taoyao/modelscope

evalscope eval \
 --model gpt-oss-20b \
 --api-url http://127.0.0.1:8085/v1 \
 --api-key EMPTY \
 --eval-type server \
 --datasets mmlu \
 --generation-config '{"do_sample":true,"temperature":0}' \
 --eval-batch-size 16 \
 --work-dir ./outputs/20251127_075917 \
 --use-cache true
  1. MMLU测试集精度测试结果:

    image

    测试时采用默认的medium思考深度,采样参数"temperature":0,在MMLU所有子测试集的平均mean_acc为83.41%,相比论文公布的84%,在合理的误差范围内。

4. 解决效果

通过FIA融合算子使能,叠加图模式优化。gpt-oss-20b平均推理时延降低84.57%,TPS提升6.48倍,TTFT降低31.97%,TPOT降低85.29%。精度测试采用默认的medium思考深度,采样参数"temperature":0,在MMLU所有子测试集的平均mean_acc为83.41%,相比论文公布的84%,精度误差在可接受的范围内。

5. 附录

5.1 遇到的问题

5.1.1 输出乱码-需配置tiktoken

需配置tiktoken缓存:

pip install openai-harmony -i https://mirrors.aliyun.com/pypi/simple/
TIKTOKEN_RS_CACHE_DIR=/home/models/hf_models/gpt-oss-20b-BF16/tiktiken_cache python -c 'from openai_harmony import load_harmony_encoding; load_harmony_encoding("HarmonyGptOss")'

5.1.2 FIA_v2算子learnable_sink参数调试报错

在对torch_npu.npu_fused_infer_attention_score_v2接口进行learnable_sink参数验证时报如下错误:拉算子侧同事定位,CANN8.5+torch_npu==2.7.1.post1.dev20251103未复现该问题; 解决过程:

  1. 更新PTA版本torch_npu==2.7.1.post1.dev20251103后问题未解决;
  2. 更新CANN版本:8.3.RC1.alpha003 ---> 8.3.RC1后问题消失。

5.1.3 Eager模式离线推理输出乱码-MoE未适配

  1. 初步怀疑是权重加载问题,通过对比小算子模式和FIA融合算子版本中间tensor,排查从哪一步开始出现精度不一致的问题。逐步排查,q\k\v在进入attention计算之前都是保持一致的,但attention计算的输出不一致: 拉算子侧同事定位,算子侧同事认为FIA算子内部计算会引入误差,但需要进一步计算cos距离相似度才能判断是否在可接收的误差范围内,dump后计算cos距离在可接受范围内(<0.99),暂时排除FIA算子问题。
  2. 采用mstt的msprobe工具进行精度定位,对比fia和小算子模式的dump.json,如下图所示,同小算子模式进行精度比对时,发现执行到moe的mlp层时,精度误差陡然增大。 image 重点分析MoE部分代码,发现vllm-ascend的MoE部分实现,未接收bias参数,并且未支持gpt-oss魔改的swigluoai激活函数。修复后乱码问题消失。

5.1.4 ACL图模式-图捕获时H2D触发同步

repeat_interleave操作触发H2D导致同步问题。

(EngineCore_DP0 pid=209735)   File "/home/taoyao/my_repo/vllm-ascend/vllm_ascend/ops/moe/moe_mlp.py", line 204, in unquant_apply_mlp
(EngineCore_DP0 pid=209735)     experts_indices = experts_indices.repeat_interleave(group_list)
(EngineCore_DP0 pid=209735)                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=209735) RuntimeError: operator():build/CMakeFiles/torch_npu.dir/compiler_depend.ts:47 NPU function error: c10_npu::acl::AclrtSynchronizeStreamWithTimeout(copy_stream), error code is 107027
(EngineCore_DP0 pid=209735) [ERROR] 2025-11-24-10:14:55 (PID:209735, Device:0, RankID:-1) ERR00100 PTA call acl api failed.
(EngineCore_DP0 pid=209735) EE9999: Inner Error!

解决办法,将bias计算融合进torch_npu.npu_grouped_matmul,由于GMM算子bias参数当前不支持fp16,目前强转为float32解决。

gate_up_out = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        bias=[w1_bias.to(dtype=torch.float32)],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]

5.1.5 ACL图模式-replay时因向上填充到对应档位时报错

执行时遇到如下问题,分析是图模式-replay时query因为做了向上填充到对应档位,导致query的token长度和actual_seq_qlen不一致;需要针对性地对query\key\value进行unpad,然后再传给FIA算子;FIA算子计算atten输出后,再填充回output;

5.1.6 ACL图模式编译时报错

vllm-ascend版本更新到v0.11.0RC1之后版本,报图编译错误: 规避方案: 注释掉/usr/local/python3.11.13/lib/python3.11/site-packages/vllm/compilation/wrapper.py 以下代码行:

# if self.vllm_config.compilation_config.cudagraph_mode != \
#     CUDAGraphMode.NONE and "update" in new_code.co_names:
#     import depyf
#     src = depyf.decompile(new_code)
#     msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
#     raise RuntimeError(msg)

5.2 参考及引用

  1. gpt-oss-120b & gpt-oss-20b Model Card
  2. gpt-oss vllm-ascend小算子版本PR
  3. evalscope benchmark测试工具使用指导文档