Ascend-SACT/DeepSeek-V3.2-w8a8

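This guide uses two deployment modes as examples, dual-node mixed PD (prefill and decode co-located) and large EP (expert parallelism), to walk through deploying the DeepSeek-V3.2-w8a8 model with the vllm-ascend framework.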

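1. Prepare the weight files

You can download third-party weights directly, or quantize the official DeepSeek-V3.2 weights released by DeepSeek.

1.1 Download the w8a8 weight files from the Modelers (魔乐) community: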

https://modelscope.cn/models/Eco-Tech/DeepSeek-V3.1-w8a8-mtp-QuaRot

1.2 Quantize to w8a8 with the msmodelslim tool:

https://gitcode.com/Ascend/msit/blob/br_release_MindStudio_8.3.0_20261231/msmodelslim/docs/%E6%94%AF%E6%8C%81%E7%9F%A9%E9%98%B5/%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%94%AF%E6%8C%81%E7%9F%A9%E9%98%B5.md

Original DS3.2 weights: https://modelscope.cn/models/deepseek-ai/DeepSeek-V3.2

2. Prepare the base image

Use the vllm-ascend 0.13.0rc1 image.

A3: quay.io/ascend/vllm-ascend:v0.13.0rc1-a3
A2: quay.io/ascend/vllm-ascend:v0.13.0rc1

3. Create the container

A3:

export IMAGE=quay.io/ascend/vllm-ascend:v0.13.0rc1-a3
docker run --privileged \
    --name vllm-ascend \
    --shm-size=1g \
    --net=host \
    --device /dev/davinci0 \
    --device /dev/davinci1 \
    --device /dev/davinci2 \
    --device /dev/davinci3 \
    --device /dev/davinci4 \
    --device /dev/davinci5 \
    --device /dev/davinci6 \
    --device /dev/davinci7 \
    --device /dev/davinci8 \
    --device /dev/davinci9 \
    --device /dev/davinci10 \
    --device /dev/davinci11 \
    --device /dev/davinci12 \
    --device /dev/davinci13 \
    --device /dev/davinci14 \
    --device /dev/davinci15 \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    -v /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
    -v /etc/ascend_install.info:/etc/ascend_install.info \
    -v /root/.cache:/root/.cache \
    -v /home/:/home/ \
    -v /opt/data/:/opt/data/ \
    -v /mnt:/mnt \
    -it $IMAGE bash

A2:

export IMAGE=quay.io/ascend/vllm-ascend:v0.13.0rc1
docker run --privileged \
    --name vllm-ascend \
    --shm-size=1g \
    --net=host \
    --device /dev/davinci0 \
    --device /dev/davinci1 \
    --device /dev/davinci2 \
    --device /dev/davinci3 \
    --device /dev/davinci4 \
    --device /dev/davinci5 \
    --device /dev/davinci6 \
    --device /dev/davinci7 \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    -v /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
    -v /etc/ascend_install.info:/etc/ascend_install.info \
    -v /root/.cache:/root/.cache \
    -v /home/:/home/ \
    -v /opt/data/:/opt/data/ \
    -v /mnt:/mnt \
    -it $IMAGE bash
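
The A3 and A2 commands above differ only in the image tag and in how many /dev/davinciN devices they expose (16 on A3, 8 on A2). As a rough sketch, the device flags could be generated rather than typed by hand. The helper name davinci_device_flags is hypothetical and not part of vllm-ascend.

```python
def davinci_device_flags(num_devices: int) -> list[str]:
    """Build the --device flags for NPUs 0..num_devices-1 plus the shared devices."""
    flags: list[str] = []
    for idx in range(num_devices):
        flags += ["--device", f"/dev/davinci{idx}"]
    # Management devices that both docker commands above pass through.
    for shared in ("davinci_manager", "devmm_svm", "hisi_hdc"):
        flags += ["--device", f"/dev/{shared}"]
    return flags


# NPU counts per node, as listed in the two docker commands above.
NPUS_PER_NODE = {"A3": 16, "A2": 8}

for platform, count in NPUS_PER_NODE.items():
    print(platform, " ".join(davinci_device_flags(count)))
```

The printed flags can be pasted into the docker run command in place of the hand-written --device lines.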

4. Install triton-ascend

Inside the container, first install the BiSheng compiler by running the following commands:

BISHENG_NAME="Ascend-BiSheng-toolkit_$(uname -i)_20251225.run"
BISHENG_URL="https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/${BISHENG_NAME}"
wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./${BISHENG_NAME}" --install && rm "${BISHENG_NAME}"
source /usr/local/Ascend/8.5.0/bisheng_toolkit/set_env.sh

Then install triton-ascend by running the following commands:

wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20251229-cp311-cp311-manylinux_2_27_$(uname -i).manylinux_2_28_$(uname -i).whl
pip install triton_ascend-3.2.0.dev20251229-cp311-cp311-manylinux_2_27_$(uname -i).manylinux_2_28_$(uname -i).whl
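
After pip finishes, a quick check from Python can confirm that the wheel was registered in the current environment. This is a minimal sketch: the distribution name triton-ascend is inferred from the wheel filename above, and installed_version is a hypothetical helper.

```python
from importlib import metadata


def installed_version(dist_name: str):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution name assumed from the wheel filename (triton_ascend-*.whl).
version = installed_version("triton-ascend")
if version is None:
    print("triton-ascend is not installed in this environment")
else:
    print(f"triton-ascend {version} is installed")
```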

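5. Deploy the model

Use ifconfig on the physical host to find nic_name.

Set local_ip on both the primary and secondary nodes; on the secondary node, also set the primary node's IP as node0_ip.

Set --max-num-seqs to the maximum concurrency your scenario requires.

Set the model weight path in the vllm serve command.

nic_name: the name of the local network interface. Distributed communication uses it to identify the physical network device, so that the communication libraries bind to the intended NIC. This avoids ambiguity on multi-NIC hosts and reduces communication latency and errors.

local_ip: the IP address of the local machine, used as this node's address for distributed communication. It is used to set environment variables such as HCCL_IF_IP and, on the primary node, is the value passed to the --data-parallel-address argument of vllm serve, so that nodes can discover and connect to each other.

HCCL_IF_IP: the IP address used by HCCL (Huawei Collective Communication Library). It binds HCCL operations to the specified IP so that the wrong network interface is not used.

VLLM_USE_MODELSCOPE: tells vllm-ascend that models may be loaded from the ModelScope platform instead of a local path. With this variable set, the model path can be shortened to an ID such as vllm-ascend/DeepSeek-V3.1-W8A8. This guide still uses locally stored weight files.

GLOO_SOCKET_IFNAME: the network interface name for the Gloo communication backend (PyTorch's default CPU communication library).

TP_SOCKET_IFNAME: like the HCCL settings, binds TP communication to the specified NIC.

HCCL_SOCKET_IFNAME: the socket interface name used by the HCCL library; it serves a similar purpose to HCCL_IF_IP.

HCCL_OP_EXPANSION_MODE: the HCCL operator expansion mode. The scripts below set it to AIV.

PYTORCH_NPU_ALLOC_CONF=expandable_segments:True: controls PyTorch memory allocation on Ascend NPUs. expandable_segments:True lets memory segments grow dynamically, which reduces fragmentation and helps avoid OOM errors.

OMP_PROC_BIND: controls whether OpenMP threads are bound to CPU cores. false disables binding, so threads may migrate between cores; this adds flexibility but can increase latency. It is usually set together with OMP_NUM_THREADS to balance compute load.

OMP_NUM_THREADS: the number of threads OpenMP uses, which affects CPU-side parallel computation.

HCCL_BUFFSIZE: the HCCL communication buffer size, used to tune multi-node communication. Adjust it to the model scale.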

VLLM_ASCEND_ENABLE_MLAPO: enables MLAPO in vllm-ascend, a fused MLA (Multi-head Latent Attention) preprocessing operator intended for DeepSeek W8A8 models. Enabling it can improve inference performance.

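VLLM_ASCEND_ENABLE_FLASHCOMM1: enables the FlashComm communication optimization, which improves communication efficiency for multi-node inference. Traditional path: NPU compute → system memory → CPU processing → network → peer CPU → peer system memory → peer NPU. FlashComm path: NPU compute → direct RDMA communication → peer NPU.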
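
To make the relationship between nic_name, local_ip and the communication variables concrete, the sketch below collects them into one dictionary. The values mirror the export lines of the section 5.1 launch scripts; build_comm_env is a hypothetical helper, and the interface name and IP are placeholders to replace with your own.

```python
import os


def build_comm_env(nic_name: str, local_ip: str, hccl_buffsize: int = 200) -> dict[str, str]:
    """Collect the communication-related variables for one node."""
    return {
        "HCCL_OP_EXPANSION_MODE": "AIV",
        "HCCL_IF_IP": local_ip,          # HCCL binds to this address
        "GLOO_SOCKET_IFNAME": nic_name,  # Gloo, TP and HCCL sockets share one NIC
        "TP_SOCKET_IFNAME": nic_name,
        "HCCL_SOCKET_IFNAME": nic_name,
        "OMP_PROC_BIND": "false",
        "OMP_NUM_THREADS": "10",
        "HCCL_BUFFSIZE": str(hccl_buffsize),
        "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True",
    }


# Placeholder interface name and documentation-range IP; substitute real values.
env = {**os.environ, **build_comm_env("eth0", "192.0.2.10")}
print(env["HCCL_IF_IP"], env["HCCL_SOCKET_IFNAME"])
```

A dictionary like env can be passed as the env argument of subprocess.run when launching vllm serve from Python instead of a shell script.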

5.1 Dual-node mixed PD deployment

5.1.1 Launch the service on the primary node (node0)

nic_name="enp48s3u1u1"
local_ip=x.x.x.x
# AIV
export HCCL_OP_EXPANSION_MODE="AIV"
export HCCL_IF_IP=$local_ip
export GLOO_SOCKET_IFNAME=$nic_name
export TP_SOCKET_IFNAME=$nic_name
export HCCL_SOCKET_IFNAME=$nic_name
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export VLLM_USE_V1=1
export HCCL_BUFFSIZE=200
export VLLM_ASCEND_ENABLE_MLAPO=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1

vllm serve /data/DeepSeek-V3.2-1201-W8A8 \
--host 0.0.0.0 \
--port 9100 \
--data-parallel-size 2 \
--data-parallel-size-local 1 \
--data-parallel-address 141.61.39.105 \
--data-parallel-rpc-port 12890 \
--tensor-parallel-size 16 \
--quantization ascend \
--seed 1024 \
--served-model-name dsv32 \
--enable-expert-parallel \
--max-num-seqs 16 \
--max-model-len 8192 \
--max-num-batched-tokens 4096 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--speculative-config '{"num_speculative_tokens": 2, "method": "deepseek_mtp"}'

5.1.2 Launch the service on the secondary node (node1)

nic_name="enp48s3u1u1"
local_ip=x.x.x.x
# AIV
export HCCL_OP_EXPANSION_MODE="AIV"
export HCCL_IF_IP=$local_ip
export GLOO_SOCKET_IFNAME=$nic_name
export TP_SOCKET_IFNAME=$nic_name
export HCCL_SOCKET_IFNAME=$nic_name
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export VLLM_USE_V1=1
export HCCL_BUFFSIZE=200
export VLLM_ASCEND_ENABLE_MLAPO=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1

vllm serve /data/DeepSeek-V3.2-1201-W8A8 \
--host 0.0.0.0 \
--port 9100 \
--headless \
--data-parallel-size 2 \
--data-parallel-size-local 1 \
--data-parallel-start-rank 1 \
--data-parallel-address 141.61.39.105 \
--data-parallel-rpc-port 12890 \
--tensor-parallel-size 16 \
--quantization ascend \
--seed 1024 \
--served-model-name dsv32 \
--enable-expert-parallel \
--max-num-seqs 16 \
--max-model-len 8192 \
--max-num-batched-tokens 4096 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--speculative-config '{"num_speculative_tokens": 2, "method": "deepseek_mtp"}'
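
The two commands above use --data-parallel-size 2, --data-parallel-size-local 1 and --tensor-parallel-size 16, so each node runs one DP rank across 16 NPUs. The following sketch works through that arithmetic and rejects layouts that would not fit on a node. check_layout is a hypothetical helper, and the 16 NPUs per node is taken from the A3 docker command in section 3.

```python
def check_layout(dp_size: int, dp_size_local: int, tp_size: int, npus_per_node: int) -> dict:
    """Derive node and NPU counts from the data-parallel and tensor-parallel sizes."""
    if dp_size % dp_size_local != 0:
        raise ValueError("dp_size must be a multiple of dp_size_local")
    used_per_node = dp_size_local * tp_size
    if used_per_node > npus_per_node:
        raise ValueError(
            f"layout needs {used_per_node} NPUs per node, only {npus_per_node} available")
    return {
        "nodes": dp_size // dp_size_local,
        "npus_used_per_node": used_per_node,
        "total_npus": dp_size * tp_size,
    }


# Values from the dual-node commands above, on A3 nodes with 16 NPUs each.
print(check_layout(dp_size=2, dp_size_local=1, tp_size=16, npus_per_node=16))
```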

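5.2 Large-EP deployment

Taking a four-node A3 1P1D setup as an example, the large-EP deployment steps are as follows. For a different server configuration, adjust nic_name, the data-parallel parameters, and the prefill and decode tp_size and dp_size accordingly.

5.2.1 vllm-ascend large-EP service configuration

Prepare a service launch script named run_dp_template.sh on every node. The script for each P (prefill) node is as follows: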

nic_name="enp48s3u1u1c2" # change to your own nic name
local_ip=x.x.x.x # change to your own ip
export HCCL_OP_EXPANSION_MODE="AIV"
export HCCL_IF_IP=$local_ip
export GLOO_SOCKET_IFNAME=$nic_name
export TP_SOCKET_IFNAME=$nic_name
export HCCL_SOCKET_IFNAME=$nic_name
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export VLLM_USE_V1=1
export HCCL_BUFFSIZE=256
export VLLM_TORCH_PROFILER_DIR="./vllm_profile"
export VLLM_TORCH_PROFILER_WITH_STACK=0
export ASCEND_AGGREGATE_ENABLE=1
export ASCEND_TRANSPORT_PRINT=1
export ACL_OP_INIT_MODE=1
export ASCEND_A3_ENABLE=1
export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=300000
export ASCEND_RT_VISIBLE_DEVICES=$1
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1

vllm serve  /data/DeepSeek-V3.2-1201-W8A8 \
    --host 0.0.0.0 \
    --port $2 \
    --data-parallel-size $3 \
    --data-parallel-rank $4 \
    --data-parallel-address $5 \
    --data-parallel-rpc-port $6 \
    --tensor-parallel-size $7 \
    --enable-expert-parallel \
    --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}' \
    --seed 1024 \
    --served-model-name dsv32 \
    --max-model-len 68000 \
    --max-num-batched-tokens 32550 \
    --trust-remote-code \
    --max-num-seqs 64 \
    --gpu-memory-utilization 0.82 \
    --quantization ascend \
    --enforce-eager \
    --no-enable-prefix-caching \
    --kv-transfer-config \
    '{"kv_connector": "MooncakeConnectorV1",
    "kv_role": "kv_producer",
    "kv_port": "30000",
    "engine_id": "0",
    "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
    "kv_connector_extra_config": {
                "use_ascend_direct": true,
                "prefill": {
                        "dp_size": 2,
                        "tp_size": 16
                },
                "decode": {
                        "dp_size": 8,
                        "tp_size": 4
                }
        }
    }'

The script for each D (decode) node is as follows:

nic_name="enp48s3u1u1c2" # change to your own nic name
local_ip=x.x.x.x # change to your own ip
export HCCL_OP_EXPANSION_MODE="AIV"
export HCCL_IF_IP=$local_ip
export GLOO_SOCKET_IFNAME=$nic_name
export TP_SOCKET_IFNAME=$nic_name
export HCCL_SOCKET_IFNAME=$nic_name
#Mooncake
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export VLLM_USE_V1=1
export HCCL_BUFFSIZE=256
export VLLM_TORCH_PROFILER_DIR="./vllm_profile"
export VLLM_TORCH_PROFILER_WITH_STACK=0
export ASCEND_AGGREGATE_ENABLE=1
export ASCEND_TRANSPORT_PRINT=1
export ACL_OP_INIT_MODE=1
export ASCEND_A3_ENABLE=1
export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=300000
export TASK_QUEUE_ENABLE=1
export ASCEND_RT_VISIBLE_DEVICES=$1
export VLLM_ASCEND_ENABLE_MLAPO=1

vllm serve /data/DeepSeek-V3.2-1201-W8A8 \
    --host 0.0.0.0 \
    --port $2 \
    --data-parallel-size $3 \
    --data-parallel-rank $4 \
    --data-parallel-address $5 \
    --data-parallel-rpc-port $6 \
    --tensor-parallel-size $7 \
    --enable-expert-parallel \
    --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}' \
    --seed 1024 \
    --served-model-name dsv32 \
    --max-model-len 68000 \
    --max-num-batched-tokens 12 \
    --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes":[3, 6, 9, 12]}' \
    --trust-remote-code \
    --max-num-seqs 4 \
    --gpu-memory-utilization 0.95 \
    --no-enable-prefix-caching \
    --async-scheduling \
    --quantization ascend \
    --kv-transfer-config \
    '{"kv_connector": "MooncakeConnectorV1",
    "kv_role": "kv_consumer",
    "kv_port": "30100",
    "engine_id": "1",
    "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
    "kv_connector_extra_config": {
                "use_ascend_direct": true,
                "prefill": {
                        "dp_size": 2,
                        "tp_size": 16
                },
                "decode": {
                        "dp_size": 8,
                        "tp_size": 4
                }
        }
    }'
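
The P and D scripts carry almost identical --kv-transfer-config JSON; only kv_role, kv_port and engine_id differ, and the prefill and decode sizes must agree on both sides. One way to keep them in sync is to generate the string, as in this sketch. The field names and values are copied from the scripts above; kv_transfer_config is a hypothetical helper.

```python
import json


def kv_transfer_config(kv_role: str, engine_id: str, kv_port: str,
                       prefill=(2, 16), decode=(8, 4)) -> str:
    """Render the JSON passed to --kv-transfer-config for one instance.

    prefill and decode are (dp_size, tp_size) pairs shared by P and D nodes.
    """
    config = {
        "kv_connector": "MooncakeConnectorV1",
        "kv_role": kv_role,
        "kv_port": kv_port,
        "engine_id": engine_id,
        "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
        "kv_connector_extra_config": {
            "use_ascend_direct": True,
            "prefill": {"dp_size": prefill[0], "tp_size": prefill[1]},
            "decode": {"dp_size": decode[0], "tp_size": decode[1]},
        },
    }
    return json.dumps(config)


# The 1P1D values used in the two scripts above.
print(kv_transfer_config("kv_producer", "0", "30000"))  # P nodes
print(kv_transfer_config("kv_consumer", "1", "30100"))  # D nodes
```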

5.2.2 Launch commands for the vllm-ascend large-EP service

Prepare the launch_online_dp.py script on all nodes:

import argparse
import multiprocessing
import os
import subprocess
import sys

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dp-size",
        type=int,
        required=True,
        help="Data parallel size."
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size."
    )
    parser.add_argument(
        "--dp-size-local",
        type=int,
        default=-1,
        help="Local data parallel size."
    )
    parser.add_argument(
        "--dp-rank-start",
        type=int,
        default=0,
        help="Starting rank for data parallel."
    )
    parser.add_argument(
        "--dp-address",
        type=str,
        required=True,
        help="IP address for data parallel master node."
    )
    parser.add_argument(
        "--dp-rpc-port",
        type=str,
        default=12345,
        help="Port for data parallel master node."
    )
    parser.add_argument(
        "--vllm-start-port",
        type=int,
        default=9000,
        help="Starting port for the engine."
    )
    return parser.parse_args()

args = parse_args()
dp_size = args.dp_size
tp_size = args.tp_size
dp_size_local = args.dp_size_local
if dp_size_local == -1:
    dp_size_local = dp_size
dp_rank_start = args.dp_rank_start
dp_address = args.dp_address
dp_rpc_port = args.dp_rpc_port
vllm_start_port = args.vllm_start_port

def run_command(visible_devices, dp_rank, vllm_engine_port):
    # Positional arguments map to $1..$7 in run_dp_template.sh.
    command = [
        "bash",
        "./run_dp_template.sh",
        visible_devices,
        str(vllm_engine_port),
        str(dp_size),
        str(dp_rank),
        dp_address,
        str(dp_rpc_port),  # the argparse default is an int, so convert explicitly
        str(tp_size),
    ]
    subprocess.run(command, check=True)

if __name__ == "__main__":
    template_path = "./run_dp_template.sh"
    if not os.path.exists(template_path):
        print(f"Template file {template_path} does not exist.")
        sys.exit(1)

    processes = []
    num_cards = dp_size_local * tp_size
    for i in range(dp_size_local):
        dp_rank = dp_rank_start + i
        vllm_engine_port = vllm_start_port + i
        visible_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
        process = multiprocessing.Process(target=run_command,
                                        args=(visible_devices, dp_rank,
                                                vllm_engine_port))
        processes.append(process)
        process.start()

    for process in processes:
        process.join()

Launch the service on each node with the following commands:

Prefill node 0

# change ip
python launch_online_dp.py --dp-size 2 --tp-size 16 --dp-size-local 1 --dp-rank-start 0 --dp-address {P0 ip} --dp-rpc-port 12890 --vllm-start-port 9100

Prefill node 1

# change ip
python launch_online_dp.py --dp-size 2 --tp-size 16 --dp-size-local 1 --dp-rank-start 1 --dp-address {P0 ip} --dp-rpc-port 12890 --vllm-start-port 9100

Decode node 0

# change ip
python launch_online_dp.py --dp-size 8 --tp-size 4 --dp-size-local 4 --dp-rank-start 0 --dp-address {D0 ip} --dp-rpc-port 12890 --vllm-start-port 9100

Decode node 1

# change ip
python launch_online_dp.py --dp-size 8 --tp-size 4 --dp-size-local 4 --dp-rank-start 4 --dp-address {D0 ip} --dp-rpc-port 12890 --vllm-start-port 9100
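
To see what each launch command starts, the sketch below reproduces the rank, port and device arithmetic of the loop in launch_online_dp.py without spawning any process. plan_local_engines is a hypothetical helper written for illustration only.

```python
def plan_local_engines(dp_size_local: int, tp_size: int,
                       dp_rank_start: int, vllm_start_port: int):
    """Return (dp_rank, port, visible_devices) for every engine on one node."""
    plan = []
    for i in range(dp_size_local):
        devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
        plan.append((dp_rank_start + i, vllm_start_port + i, devices))
    return plan


# Decode node 1 from the commands above: 4 local engines with TP 4, ranks start at 4.
for dp_rank, port, devices in plan_local_engines(4, 4, 4, 9100):
    print(f"dp_rank={dp_rank} port={port} ASCEND_RT_VISIBLE_DEVICES={devices}")
```

Each decode node therefore serves four engines on ports 9100 to 9103, which is why the proxy configuration in section 5.2.3 lists four ports per D node.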

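Run the commands above on all nodes of the large-EP deployment to start the online service.

In the launch commands above, --dp-address groups the four nodes into a 1P1D large-EP deployment: both prefill nodes point at P0 and both decode nodes point at D0.

For a 2P1D setup, make the following changes:

1. Launch command: set --dp-address in the Prefill node 1 launch command to the IP of Prefill node 1.

2. In run_dp_template.sh on the second P node, set "engine_id": "1" and "kv_port": "30100".

3. In run_dp_template.sh on both D nodes, set "engine_id": "2" and "kv_port": "30200".

For other large-EP configurations, adjust --dp-address in the launch commands and engine_id and kv_port in run_dp_template.sh in the same way, so that each P and D instance has a distinct engine_id and kv_port.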
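
The engine_id and kv_port values quoted for 1P1D and 2P1D follow a simple pattern: engine IDs count up from 0 and ports step by 100 from 30000. The sketch below generalises that pattern under the assumption that it also holds for other layouts, which the guide does not state explicitly. assign_kv_endpoints is a hypothetical helper.

```python
def assign_kv_endpoints(instances, base_port: int = 30000, step: int = 100) -> dict:
    """Give every P or D instance a distinct engine_id and kv_port.

    The base port and step are inferred from the 1P1D and 2P1D values
    in this guide; they are an assumption, not a documented rule.
    """
    return {
        name: {"engine_id": str(idx), "kv_port": str(base_port + idx * step)}
        for idx, name in enumerate(instances)
    }


print(assign_kv_endpoints(["P", "D"]))         # 1P1D
print(assign_kv_endpoints(["P0", "P1", "D"]))  # 2P1D
```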

5.2.3 Start the request-forwarding node

Prepare two files: dp_load_balance_proxy_server.py and toy_proxy.sh.

Prepare the dp_load_balance_proxy_server.py file:

# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import heapq
import os
import sys
from contextlib import asynccontextmanager
from typing import List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)

# Add uvloop for faster event loop if available
try:
    import uvloop
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass


class ServerState:

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.url = f'http://{host}:{port}/v1'
        self.client = httpx.AsyncClient(timeout=None,
                                        base_url=self.url,
                                        limits=httpx.Limits(
                                            max_connections=100000,
                                            max_keepalive_connections=100000))
        self.active_tokens = 0
        self.active_kv_cache = 0  # Only for prefiller
        self.active_requests = 0  # Number of active requests
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead


class ProxyState:

    def __init__(self, prefiller_instances, decoder_instances):
        self.prefillers: List[ServerState] = [
            ServerState(h, p) for h, p in prefiller_instances
        ]
        self.decoders: List[ServerState] = [
            ServerState(h, p) for h, p in decoder_instances
        ]
        self.req_to_prefiller = {}
        self.req_id_lock = asyncio.Lock()
        self.req_id_counter = 0
        # Removed selection locks - no longer needed for synchronous methods

        # Initialize priority queues for efficient server selection
        # Each entry is (priority_score, server_index, server_reference)
        # Lower priority score = higher priority (less loaded)
        self.prefiller_heap = [(0, i, server)
                               for i, server in enumerate(self.prefillers)]
        self.decoder_heap = [(0, i, server)
                             for i, server in enumerate(self.decoders)]
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
        server = self.prefillers[server_idx]
        # Priority based on active_tokens and active_kv_cache
        priority = server.active_tokens + server.active_kv_cache * 0.3
        # Remove old entry and add new one
        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
                               if i != server_idx]
        heapq.heappush(self.prefiller_heap, (priority, server_idx, server))

    def _update_decoder_priority(self, server_idx: int):
        """Update the priority of a decoder server in the heap."""
        server = self.decoders[server_idx]
        priority = server.active_tokens
        # Remove old entry and add new one
        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
                             if i != server_idx]
        heapq.heappush(self.decoder_heap, (priority, server_idx, server))

    def abort_prefiller_request(self, server_idx: int,
                                request_id):  # Changed to synchronous
        """
        Mark a request as aborted. This helps release the KV cache on the
        prefiller node.
        """
        # No lock needed - atomic operation
        self.prefillers[server_idx].aborted_requests.add(request_id)

    def aquire_aborted_prefiller_requests(
            self, server_idx: int):  # Changed to synchronous
        """
        Get the set of aborted requests and clear it.
        This is used to release kv cache in prefiller node.
        """
        # No lock needed - atomic operation
        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
        self.prefillers[server_idx].aborted_requests.clear()
        return aborted_requests

    async def next_req_id(self):
        async with self.req_id_lock:
            self.req_id_counter += 1
            return str(self.req_id_counter)

    def select_prefiller(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.prefiller_heap:
            raise RuntimeError("No prefiller servers available")

        priority, chosen, server = heapq.heappop(self.prefiller_heap)

        # Update the chosen server atomically
        self.prefillers[chosen].active_tokens += token_count
        self.prefillers[chosen].active_kv_cache += token_count

        # Update priority and re-add to heap
        self._update_prefiller_priority(chosen)

        return chosen

    def release_prefiller(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.prefillers[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if self.prefillers[idx].active_kv_cache > 0:
            self.prefillers[idx].active_kv_cache -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def select_decoder(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.decoder_heap:
            raise RuntimeError("No decoder servers available")

        priority, chosen, server = heapq.heappop(self.decoder_heap)

        # Update the chosen server atomically
        self.decoders[chosen].active_tokens += token_count

        # Update priority and re-add to heap
        self._update_decoder_priority(chosen)

        return chosen

    def release_decoder(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.decoders[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_decoder_priority(idx)

    # Omni_infer's calculate_input_scores function
    def calculate_prefill_scores(self, request_length: int) -> float:
        length_score = request_length / 4.0
        input_score = length_score * 0.0345 + 120.0745
        return input_score

    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length


proxy_state = None


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        type=int,
                        nargs="+",
                        default=[8001])
    parser.add_argument("--decoder-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--max-retries",
                        type=int,
                        default=3,
                        help="Maximum number of retries for HTTP requests")
    parser.add_argument(
        "--retry-delay",
        type=float,
        default=0.001,
        help="Base delay (seconds) for exponential backoff retries")
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args


@asynccontextmanager
async def lifespan(app: FastAPI):
    global proxy_state
    proxy_state = ProxyState(global_args.prefiller_instances,
                             global_args.decoder_instances)
    print(
        f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
    )
    yield
    for p in proxy_state.prefillers:
        await p.client.aclose()
    for d in proxy_state.decoders:
        await d.client.aclose()


app = FastAPI(lifespan=lifespan)


async def send_request_to_service(client: httpx.AsyncClient,
                                  prefiller_id: int,
                                  endpoint: str,
                                  req_data: dict,
                                  request_id: str,
                                  max_retries: int = 3,
                                  base_delay: float = 0.2):
    aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
        prefiller_id)
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            response = await client.post(endpoint,
                                         json=req_data,
                                         headers=headers)
            response.raise_for_status()
            return response
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            logger.warning(
                f"Attempt {attempt} failed for {endpoint}: {str(e)}")
            last_exc = e
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for {endpoint}.")
                raise last_exc


async def stream_service_response_with_retry(client: httpx.AsyncClient,
                                             endpoint: str,
                                             req_data: dict,
                                             request_id: str,
                                             max_retries: int = 3,
                                             base_delay: float = 0.2):
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    for attempt in range(1, max_retries + 1):
        try:
            async with client.stream("POST",
                                     endpoint,
                                     json=req_data,
                                     headers=headers) as response:
                response.raise_for_status()
                first_chunk_sent = False
                async for chunk in response.aiter_bytes():
                    first_chunk_sent = True
                    yield chunk
                return  # Success, exit after streaming
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                )
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming {endpoint}."
                )
                raise e
        except Exception as e:
            # If any chunk has been sent, do not retry, just log and drop
            if 'first_chunk_sent' in locals() and first_chunk_sent:
                logger.error(
                    f"Streaming to client interrupted after response started: {str(e)}"
                )
                return
            else:
                if attempt < max_retries:
                    logger.warning(
                        f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                    )
                    await asyncio.sleep(base_delay * (2**(attempt - 1)))
                else:
                    logger.error(
                        f"All {max_retries} attempts failed for streaming {endpoint}."
                    )
                    raise e


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
        logger.debug(
            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
        )
        request_id = await proxy_state.next_req_id()
        # Select prefiller
        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
        prefiller = proxy_state.prefillers[prefiller_idx]
        # Send request to prefiller
        response = await send_request_to_service(
            prefiller.client,
            prefiller_idx,
            api,
            req_data,
            request_id,
            max_retries=global_args.max_retries,
            base_delay=global_args.retry_delay)
        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params
        # Select decoder
        decoder_score = proxy_state.calculate_decode_scores(request_length)
        logger.debug("Decoder score: %f", decoder_score)
        # Use the prefiller's kv_transfer_params to select decoder
        decoder_idx = proxy_state.select_decoder(decoder_score)
        decoder = proxy_state.decoders[decoder_idx]
        logger.debug("Using %s %s", prefiller.url, decoder.url)
        # Stream response from decoder
        released_kv = False

        async def generate_stream():
            nonlocal released_kv
            # Only one await per chunk, minimal logic in loop
            try:
                async for chunk in stream_service_response_with_retry(
                        decoder.client,
                        api,
                        req_data,
                        request_id=request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay):
                    if not released_kv and chunk:
                        proxy_state.release_prefiller_kv(
                            prefiller_idx, prefiller_score)
                        released_kv = True
                    yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                )
                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
                proxy_state.release_prefiller_kv(prefiller_idx,
                                                 prefiller_score)

            # After streaming done, release tokens
            proxy_state.release_decoder(decoder_idx, decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
    except Exception as e:
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              f" - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/completions")
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    return {
        "status": "ok",
        "prefill_instances": len(proxy_state.prefillers),
        "decode_instances": len(proxy_state.decoders)
    }


if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)

Prepare the toy_proxy.sh file:

python dp_load_balance_proxy_server.py \
  --port 8000 \
  --host 0.0.0.0 \
  --prefiller-hosts \
     {P0 ip} \
     {P1 ip} \
  --prefiller-ports  \
    9100 \
    9100 \
  --decoder-hosts \
    {D0 ip}  \
    {D0 ip}  \
    {D0 ip}  \
    {D0 ip}  \
    {D1 ip}  \
    {D1 ip}  \
    {D1 ip}  \
    {D1 ip}  \
  --decoder-ports  \
    9100 9101 9102 9103 \
    9100 9101 9102 9103

For a 2P1D setup, the toy_proxy file is as follows:

python dp_load_balance_proxy_server.py \
  --port 8000 \
  --host 80.6.1.13 \
  --prefiller-hosts {P0 ip} {P0 ip} {P1 ip} {P1 ip} \
  --prefiller-ports 9100 9101 9100 9101 \
  --decoder-hosts {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip} {D0 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip}  {D1 ip} \
  --decoder-ports 9100 9101 9102 9103 9104 9105 9106 9107 9108 9109 9110 9111 9112 9113 9114 9115 9100 9101 9102 9103 9104 9105 9106 9107 9108 9109 9110 9111 9112 9113 9114 9115
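
The long host and port lists above are easy to get out of sync when typed by hand. The following is a minimal sketch that generates them, assuming each node's instances listen on consecutive ports starting at 9100 (the pattern the commands above follow). The function name build_proxy_command and the 192.0.2.x addresses are placeholders for illustration and are not part of vllm-ascend.

```python
import shlex


def build_proxy_command(prefillers, decoders, host="0.0.0.0", port=8000,
                        base_port=9100):
    """Build the argv for dp_load_balance_proxy_server.py.

    prefillers / decoders: lists of (ip, instance_count) pairs.
    Assumption: instance i on a node listens on base_port + i.
    """

    def expand(nodes):
        hosts, ports = [], []
        for ip, count in nodes:
            for i in range(count):
                hosts.append(ip)
                ports.append(str(base_port + i))
        return hosts, ports

    p_hosts, p_ports = expand(prefillers)
    d_hosts, d_ports = expand(decoders)
    return [
        "python", "dp_load_balance_proxy_server.py",
        "--port", str(port),
        "--host", host,
        "--prefiller-hosts", *p_hosts,
        "--prefiller-ports", *p_ports,
        "--decoder-hosts", *d_hosts,
        "--decoder-ports", *d_ports,
    ]


# Placeholder addresses: two P nodes with 2 instances each,
# two D nodes with 16 instances each.
argv = build_proxy_command(
    prefillers=[("192.0.2.1", 2), ("192.0.2.2", 2)],
    decoders=[("192.0.2.3", 16), ("192.0.2.4", 16)],
)
print(shlex.join(argv))
```

Because hosts and ports are produced by the same loop, the two lists always have the same length, which is what the proxy needs in order to pair them up.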

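5.3 Verification with curl

Send a curl request to any mixed-deployment node or PD node to verify that the online service has been deployed successfully.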

curl http://0.0.0.0:9100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "dsv32",
        "prompt": "The future of AI is",
        "max_tokens": 50,
        "temperature": 0
    }'

Send a curl request to the proxy node to verify that request forwarding and load balancing work.

curl http://0.0.0.0:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "dsv32",
        "prompt": "The future of AI is",
        "max_tokens": 50,
        "temperature": 0
    }'
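
The same check can be scripted. Below is a minimal sketch using only the Python standard library; it builds the same JSON body as the curl commands above. The helper names build_completion_request and send, and the 127.0.0.1 address, are assumptions for illustration. The network call is left commented out so the snippet can be read and run without a live service.

```python
import json
import urllib.request


def build_completion_request(base_url, model="dsv32",
                             prompt="The future of AI is",
                             max_tokens=50, temperature=0):
    """Build a POST request for the /v1/completions endpoint."""
    body = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=60):
    """Send the request and return the decoded JSON response."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Port 9100 targets a serving node directly; port 8000 targets the proxy.
direct = build_completion_request("http://127.0.0.1:9100")
proxied = build_completion_request("http://127.0.0.1:8000")
# print(send(proxied))  # uncomment once the service is running
```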

6. Notes

6.1 Feature notes

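[flashcomm]

Enable by setting the environment variable export VLLM_ASCEND_ENABLE_FLASHCOMM1=1; it reduces TTFT (time to first token). Enable it only on P nodes; D nodes do not need it. In mixed deployments, flashcomm cannot be enabled together with the full_decode_only graph mode, because the asynchronous broadcast stream in flashcomm is not returned. In the current v0.13.0rc1 release, leaving flashcomm off is recommended for mixed deployments.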

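[mlapo large fused operator]

Enable by setting the environment variable export VLLM_ASCEND_ENABLE_MLAPO=1. This feature fuses the attention preprocessing steps, which significantly cuts operator dispatch time and can lower TPOT by 5-10 ms. Because the operator is limited to fewer than 1024 tokens, it can only be enabled on D nodes; enabling it on P nodes causes an error.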

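[mtp]

Enable with --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'; it must be enabled on both P nodes and D nodes. num_speculative_tokens sets the maximum number of extra tokens generated in one inference step, and enabling it lowers TPOT. The acceptance rate of the official DeepSeek-V3.2-w8a8 release is slightly higher than that of the exp version, and the acceptance rate on the vllm-benchmark random dataset is also higher than on the gsm8k simulated dataset tested with aisbench.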

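[Asynchronous scheduling]

Enable with --async-scheduling; it takes effect only on D nodes. This feature runs the input_prepare step of the mtp draft model asynchronously, which shortens TPOT. It depends on the triton operator of the rejection sampler, so triton-ascend must be installed in the environment for it to help; otherwise enabling it may degrade performance.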

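[lightning_indexer and sfa fused operator migration]

This feature takes effect inside the vllm-ascend framework. vllm-ascend v0.12.0rc1 has been observed to report a missing lightning_indexer error on A2 machines. The fix, PR #5082, has been merged, so be sure to use v0.13.0rc1.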

[full_decode_only graph mode]

Switch to graph mode with --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}'. This mode captures the decode-stage attention operations in the computation graph, which effectively lowers TPOT.
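
The per-role rules in this section can be summarized in a few lines of code. This is a minimal sketch for the PD-disaggregated case only (mixed deployments are excluded, since flashcomm and full_decode_only conflict there). The function name feature_settings and the has_triton_ascend parameter are hypothetical; the environment variables and flags are the ones named above.

```python
import json


def feature_settings(role, num_speculative_tokens=2, has_triton_ascend=True):
    """Return (env, args) for a "prefill" or "decode" node.

    Hypothetical helper that encodes the rules from section 6.1.
    """
    if role not in ("prefill", "decode"):
        raise ValueError(f"unknown role: {role!r}")

    env = {}
    # mtp must be enabled on both P and D nodes.
    args = [
        "--speculative-config",
        json.dumps({"num_speculative_tokens": num_speculative_tokens,
                    "method": "deepseek_mtp"}),
    ]

    if role == "prefill":
        # flashcomm: P nodes only.
        env["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"
    else:
        # mlapo: D nodes only (the operator is limited to < 1024 tokens).
        env["VLLM_ASCEND_ENABLE_MLAPO"] = "1"
        # Graph mode for the decode stage.
        args += ["--compilation-config",
                 json.dumps({"cudagraph_mode": "FULL_DECODE_ONLY"})]
        # Async scheduling only helps when triton-ascend is installed.
        if has_triton_ascend:
            args.append("--async-scheduling")
    return env, args


for role in ("prefill", "decode"):
    env, args = feature_settings(role)
    print(role, env, args)
```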

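6.2 Parameter tuning

[max_model_len]

Set this parameter to the sum of max_input_len and max_output_len; this matters most when device memory is tight. If it is set too large, kv_cache takes up more device memory and the service may fail to start.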

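[max_num_batched_tokens]

Enabling chunked prefill is currently recommended. With it enabled, this parameter can take any value, but values that are too small or too large are generally discouraged. Too small a value splits a single request into many chunks that are scheduled many times, which degrades performance; too large a value makes the activations during inference too large and can easily trigger OOM. Choose a value that balances the two.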
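
To make the "too small" case concrete, here is a rough sketch that estimates how many scheduling steps one prompt needs under chunked prefill. It is a simplified model that assumes the request has the whole token budget to itself; the real scheduler shares the budget among requests, so actual counts can be higher. The function name prefill_chunks and the 32768-token prompt are illustrative only.

```python
import math


def prefill_chunks(prompt_len, max_num_batched_tokens):
    """Estimate the number of prefill chunks for a single prompt."""
    if prompt_len <= 0 or max_num_batched_tokens <= 0:
        raise ValueError("both arguments must be positive")
    return math.ceil(prompt_len / max_num_batched_tokens)


# A 32768-token prompt under three different token budgets.
for budget in (512, 4096, 16384):
    print(budget, prefill_chunks(32768, budget))
# 512 -> 64 chunks, 4096 -> 8 chunks, 16384 -> 2 chunks
```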

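[max_num_seqs]

This parameter is the maximum number of requests scheduled at once. In the rc1 release, a smaller value is recommended on D nodes: D nodes have been observed to hang in multi-token scenarios, and the issue is still being investigated.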

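[gpu-memory-utilization]

This parameter sets the device memory utilization. If memory pressure on P nodes is high, set it to 0.8-0.9; D nodes are under less memory pressure and can use 0.9-0.95. Too large a value raises the OOM risk; too small a value may leave too little memory for kv_cache when the service starts.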

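[cudagraph_capture_sizes]

cudagraph_capture_sizes takes effect only after cudagraph_mode:FULL_DECODE_ONLY is set. Multiple sizes can be listed in ascending order; the largest is (mtp count + 1) * max_num_seqs, and the other values step down in multiples of (mtp count + 1). Precompiling the computation graphs for different concurrency levels greatly reduces runtime operator dispatch overhead and improves inference performance. Set this parameter on Decode nodes.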
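
The sizing rule above can be worked through in code. This is a minimal sketch, assuming "mtp count" means num_speculative_tokens and that the capture sizes go into the same --compilation-config JSON as cudagraph_mode. The function name capture_sizes and the choice of concurrency tiers are illustrative, not recommended values.

```python
import json


def capture_sizes(num_speculative_tokens, max_num_seqs,
                  seq_tiers=(1, 2, 4, 8)):
    """Return ascending capture sizes for the given concurrency tiers.

    Every size is a multiple of (num_speculative_tokens + 1), and the
    largest is (num_speculative_tokens + 1) * max_num_seqs.
    """
    step = num_speculative_tokens + 1
    sizes = {step * n for n in seq_tiers if 0 < n <= max_num_seqs}
    sizes.add(step * max_num_seqs)
    return sorted(sizes)


sizes = capture_sizes(num_speculative_tokens=2, max_num_seqs=16)
print(sizes)  # [3, 6, 12, 24, 48]

compilation_config = json.dumps({
    "cudagraph_mode": "FULL_DECODE_ONLY",
    "cudagraph_capture_sizes": sizes,
})
print("--compilation-config", repr(compilation_config))
```

With 2 speculative tokens, each decode request contributes 3 tokens per step, so a batch of 16 requests needs a graph of size 48.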

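6.3 Configuration notes

In large-EP deployments, P-node and D-node parameters differ for the following reasons:

  1. The prefill stage is compute-bound, while the decode stage is memory-bound. P nodes process prompts, so their token counts are large and cannot be captured in a computation graph; they are usually started in single-operator mode. D nodes handle few tokens and can be captured in a graph (by setting "cudagraph_mode":"FULL_DECODE_ONLY"), so they are usually started in graph mode.

  2. An A3 card has two dies, so treat it as having twice the parallel capacity of A2 and configure tensor parallelism and data parallelism accordingly.

  3. If the workload needs to handle very long input sequences, try raising the device memory utilization or adding more P-node servers.

  4. If the workload prioritizes throughput, set max-model-len to the smallest value that meets business needs and, as long as no OOM occurs, set max-num-batched-tokens as large as possible to accumulate tokens and improve compute utilization and throughput. If the workload prioritizes low latency, do not set max-num-batched-tokens too large.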