| 环境配置 | 配置说明 |
|---|---|
| 硬件配置 | Atlas A2 910B |
| CANN版本 | 8.3rc1 |
| 推理框架 | mindIE2.3 |
| 部署方式 | 8卡 部署 |
安装下载工具
pip install modelscope
下载模型权重
modelscope download --model Qwen/Qwen-Image-2512 --local_dir /weight/Qwen-Image-2512
mindIE镜像文件(下载mindIE官方镜像) mindie_2.3.0-py311-openeuler24.03:A2.0.2
pip install diffusers==0.35.1
pip install transformers==4.52.4
pip install yunchang==0.6.0
git clone https://modelers.cn/MindIE/Qwen-Image.git && cd Qwen-Image
DIFFUSERS_PATH=$(python -c "import diffusers; import os; print(os.path.dirname(diffusers.__file__))")
cp -r pipeline_qwenimage.py "$DIFFUSERS_PATH/pipelines/qwenimage/pipeline_qwenimage.py"
cp -r transformer_qwenimage.py "$DIFFUSERS_PATH/models/transformers/transformer_qwenimage.py"
使用910B上的8卡进行在线推理服务部署。
如果需要开启在线推理服务,环境配置同第3章所示,需要将run.sh 所调用的 run.py文件做一定的适配修改。
pip install flask
将原 run_cfg_usp.py 的离线推理代码修改为如下使用 flask 框架的在线推理服务,避免每次推理都要重新加载权重,原 run_cfg_usp.py 的内容请做好备份。 注:服务IP和端口号在 run_flask_app() 函数中设定,代码第511行,注意修改为实际环境的IP和端口号;
import sys
import os
import json
import argparse
import functools
import time
import torch
import torch_npu
import threading
import queue
import io
from flask import Flask, request, send_file, jsonify
#-------------------解决 diffuser 0.35.1 torch2.1 报错----------------
def custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None, tags=None):
    """No-op stand-in with the call shape of ``torch.library.custom_op``.

    Every argument except ``fn`` is accepted and ignored. Nothing is
    registered anywhere; the decorated function is handed back untouched.

    Returns:
        ``fn`` itself when it is supplied positionally, otherwise a decorator
        that returns whatever function it is applied to.
    """
    def _passthrough(target):
        return target

    return _passthrough if fn is None else _passthrough(fn)
def register_fake(op, fn=None, /, *, lib=None, _stacklevel: int = 1, allow_override: bool = False):
    """No-op stand-in with the call shape of ``torch.library.register_fake``.

    Every argument except ``fn`` is accepted and ignored. No fake
    implementation is registered; the function is handed back untouched.

    Returns:
        ``fn`` itself when it is supplied positionally, otherwise a decorator
        that returns whatever function it is applied to.
    """
    def _passthrough(target):
        return target

    return _passthrough if fn is None else _passthrough(fn)
# Install the no-op shims only when the running torch build does not provide
# these APIs itself (the workaround targets diffusers 0.35.1 on torch 2.1).
# Overwriting them unconditionally would replace the real implementations
# with no-ops on torch versions that do ship them.
if not hasattr(torch.library, "custom_op"):
    torch.library.custom_op = custom_op
if not hasattr(torch.library, "register_fake"):
    torch.library.register_fake = register_fake
#-----------------------------------------------------------------------
import torch.distributed as dist
from typing import Optional, Tuple, Union, List, Dict, Any
import numpy as np
import gc # 新增:垃圾回收模块
from PIL import Image
from mindiesd import CacheConfig, CacheAgent
from mindiesd import attention_forward
# Cache switches, read once from the environment at import time.
# The value is parsed as an integer: 0 (or unset) disables, anything else enables.
COND_CACHE = int(os.environ.get('COND_CACHE', 0)) != 0
UNCOND_CACHE = int(os.environ.get('UNCOND_CACHE', 0)) != 0
from diffusers import DiffusionPipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from qwenimage.pipeline_qwenimage import QwenImagePipeline
import logging
#torch.npu.set_per_process_memory_fraction(0.9)
class FlexibleArgumentParser(argparse.ArgumentParser):
    """Argument parser that accepts dashes in place of underscores in long option names.

    Before regular parsing, every token starting with ``--`` has the dashes in
    its *name* rewritten to underscores, so ``--foo-bar`` is parsed as
    ``--foo_bar``. For ``--name=value`` tokens only the part before the first
    ``=`` is rewritten; the value is left untouched. Tokens that do not start
    with ``--`` pass through unchanged.
    """

    @staticmethod
    def _normalize_token(token):
        """Return ``token`` with dashes in a ``--long-option`` name turned into underscores."""
        if not token.startswith("--"):
            return token
        # partition() yields an empty separator/value when there is no "=",
        # which reproduces the bare "--name" case.
        option_name, separator, option_value = token[len("--"):].partition("=")
        return "--" + option_name.replace("-", "_") + separator + option_value

    def parse_args(self, args=None, namespace=None):
        """Normalise long option names, then delegate to ``ArgumentParser.parse_args``.

        When ``args`` is None the tokens are taken from ``sys.argv[1:]``.
        """
        raw_tokens = sys.argv[1:] if args is None else args
        normalized = [self._normalize_token(token) for token in raw_tokens]
        return super().parse_args(normalized, namespace)
from qwenimage.distributed.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env
from qwenimage.distributed.parallel_mgr import (
get_sequence_parallel_world_size,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_cfg_group,
init_distributed_environment,
initialize_model_parallel,
get_sequence_parallel_rank,
get_sp_group
)
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttentionProcessor
from qwenimage.transformer_qwenimage import (
QwenDoubleStreamAttnProcessor2_0,
QwenImageTransformer2DModel,
QwenEmbedRope,
apply_rotary_emb_qwen,
AdaLayerNorm,
)
from qwenimage.distributed.all_to_all import all_to_all_4D, SeqAllToAll4D
# 1. 并行注意力处理器(继承自Qwen原生处理器)
class xFuserQwenDoubleStreamAttnProcessor(QwenDoubleStreamAttnProcessor2_0):
    """
    Qwen double-stream attention processor with Ulysses sequence-parallel (USP) support.

    Subclasses the Qwen processor and overrides ``__call__`` so that the image
    stream Q/K/V are exchanged across the sequence-parallel group
    (``SeqAllToAll4D``) before and after the fused attention call, while the
    text stream is split by attention head and gathered back afterwards.
    """

    def __init__(self):
        super().__init__()  # run the base Qwen processor initialisation
        self.ulysses_pg = get_sp_group().ulysses_group
        # Dimension indices handed to SeqAllToAll4D (swapped for the reverse exchange).
        self.scatter_idx = 2
        self.gather_idx = 1

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # image stream (B, S_img/P, D)
        encoder_hidden_states: torch.FloatTensor,  # text stream (B, S_txt, D)
        encoder_hidden_states_mask: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (img_freqs, txt_freqs)
        img_pad_len=None,
        txt_pad_len=None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run joint text+image attention with the image sequence sharded across SP ranks.

        Args:
            attn: Attention module providing the projections and norms.
            hidden_states: This rank's shard of the image stream.
            encoder_hidden_states: The text stream.
            encoder_hidden_states_mask: Accepted but not used by this processor.
            attention_mask: Accepted but not used by this processor.
            image_rotary_emb: Optional ``(img_freqs, txt_freqs)`` pair for RoPE.
            img_pad_len: Number of positions the last rank's image shard is
                short by; ``None`` is treated as 0 (no padding).
            txt_pad_len: Accepted but not used by this processor.

        Returns:
            ``(img_attn_output, txt_attn_output)``.
        """
        # The default is None; comparing None with an int raises TypeError, so
        # normalise it to "no padding" up front.
        if img_pad_len is None:
            img_pad_len = 0

        sp_world_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        is_last_rank = sp_rank == sp_world_size - 1

        img_original_len = hidden_states.shape[1] * sp_world_size
        # Pad before the first exchange: only the last rank holds a short
        # shard, so only it is padded; every rank computes the unpadded length.
        if img_pad_len > 0:
            img_original_len = hidden_states.shape[1] * sp_world_size - img_pad_len
            if is_last_rank:
                img_original_len = (hidden_states.shape[1] + img_pad_len) * sp_world_size - img_pad_len
                hidden_states = torch.nn.functional.pad(
                    hidden_states,
                    (0, 0, 0, img_pad_len, 0, 0),  # pad only the end of the sequence dim (dim=1)
                )

        # ---------------- image stream QKV (sample projections) ----------------
        # Each all-to-all is launched right after its projection; the returned
        # object is called later to obtain the exchanged tensor.
        img_query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1))
        query_sycn = SeqAllToAll4D.apply(
            self.ulysses_pg, img_query, self.scatter_idx, self.gather_idx
            # [B, S_image/ulysses_size, H, D] --> [B, S_image, H/ulysses_size, D]
        )
        img_key = attn.to_k(hidden_states).unflatten(-1, (attn.heads, -1))
        key_sycn = SeqAllToAll4D.apply(
            self.ulysses_pg, img_key, self.scatter_idx, self.gather_idx
        )
        img_value = attn.to_v(hidden_states).unflatten(-1, (attn.heads, -1))
        value_sycn = SeqAllToAll4D.apply(
            self.ulysses_pg, img_value, self.scatter_idx, self.gather_idx
        )

        # ---------------- text stream QKV (context projections) ----------------
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)
        txt_query = txt_query.unflatten(-1, (attn.heads, -1))  # (B, S_txt, H, D_head)
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        img_query = query_sycn()
        img_key = key_sycn()
        # Drop the padding rows once the first exchange has completed.
        if img_pad_len > 0:
            img_query = img_query[:, :img_original_len, :, :].contiguous()
            img_key = img_key[:, :img_original_len, :, :].contiguous()

        # QK normalisation
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        img_value = value_sycn()
        if img_pad_len > 0:
            img_value = img_value[:, :img_original_len, :, :].contiguous()

        # Each rank keeps only its slice of the text heads:
        # [B, S_text, H, D] --> [B, S_text, H/ulysses_size, D]
        txt_query = torch.chunk(txt_query, sp_world_size, dim=2)[sp_rank]
        txt_key = torch.chunk(txt_key, sp_world_size, dim=2)[sp_rank]
        txt_value = torch.chunk(txt_value, sp_world_size, dim=2)[sp_rank]

        joint_query = torch.cat([txt_query, img_query], dim=1)  # (B, S_txt + S_img, H/ulysses_size, D_head)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        out = attention_forward(
            joint_query,
            joint_key,
            joint_value,
            opt_mode="manual",
            op_type="fused_attn_score",
            layout="BNSD"
        )
        if isinstance(out, tuple):
            context_layer, _, _ = out
        else:
            context_layer = out

        # Split the joint result back into text and image parts (text comes first).
        txt_seq_len = txt_query.shape[1]
        text_out = context_layer[:, :txt_seq_len, :, :].contiguous()  # force contiguous memory
        image_out = context_layer[:, txt_seq_len:, :, :].contiguous()

        # Pad before the second exchange so the sequence dim is evenly divisible again.
        if img_pad_len > 0:
            image_out = torch.nn.functional.pad(
                image_out,
                (0, 0, 0, 0, 0, img_pad_len, 0, 0),  # pad only the end of the sequence dim (dim=1)
            )
        img_attn_output = SeqAllToAll4D.apply(
            self.ulysses_pg, image_out, self.gather_idx, self.scatter_idx, True
            # [B, S_image, H/ulysses_size, D] --> [B, S_image/ulysses_size, H, D]
        ).flatten(2, 3).to(img_query.dtype)
        # After the second exchange only the last rank's shard carries the padding rows.
        if img_pad_len > 0 and is_last_rank:
            img_attn_output_len = img_attn_output.shape[1]
            img_attn_output = img_attn_output[:, :img_attn_output_len - img_pad_len, :].contiguous()

        # Start gathering the text heads, then run the image output projection
        # before collecting the gathered result.
        # (B, S_txt, H/ulysses_size, D_head) --> (B, S_txt, H, D_head)
        txt_attn_asyn = get_sp_group().all_gather(text_out, dim=2, async_op=True)

        # Output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout
        txt_attn_output = txt_attn_asyn().flatten(2, 3).to(img_query.dtype)
        txt_attn_output = attn.to_add_out(txt_attn_output)
        return img_attn_output, txt_attn_output
# 2. Transformer并行化函数(适配Qwen-Image结构)
def parallelize_qwen_image_transformer(pipe: DiffusionPipeline):
    """
    Patch ``pipe.transformer`` in place for sequence-parallel (USP) inference.

    Two things are replaced on the transformer:
      * ``forward`` is rebound to ``parallel_forward`` below, which shards the
        image sequence across the sequence-parallel ranks and gathers the
        projected output at the end;
      * every block's ``attn.processor`` becomes an
        ``xFuserQwenDoubleStreamAttnProcessor``.

    Returns None.
    """
    transformer = pipe.transformer

    @functools.wraps(transformer.__class__.forward)
    def parallel_forward(
        self,
        hidden_states: torch.Tensor,  # image stream input (B, S_img, D_in)
        encoder_hidden_states: torch.Tensor,  # text stream input (B, S_txt, D_joint)
        encoder_hidden_states_mask: Optional[torch.Tensor] = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,  # Qwen specific: (frame, H, W)
        txt_seq_lens: Optional[List[int]] = None,  # Qwen specific: list of text lengths
        guidance: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        return_dict: bool = True,
        use_cache: bool = False,
        if_cond: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        """Sequence-parallel forward pass.

        ``use_cache`` routes each block through ``self.cache_cond`` (when
        ``if_cond`` is true) or ``self.cache_uncond``; if the selected cache
        agent is missing or None the block is called directly instead.

        Returns a ``Transformer2DModelOutput`` or, when ``return_dict`` is
        false, a one-element tuple holding the output tensor.
        """
        # Imported under an alias so the stdlib ``logging`` module is not shadowed.
        from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
        from diffusers.utils import logging as diffusers_logging

        # Remember whether a scale was supplied *before* popping it; checking
        # after the pop can never be true.
        scale_was_given = False
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            scale_was_given = attention_kwargs.get("scale", None) is not None
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_given:
            # diffusers.utils.logging exposes get_logger(); it has no module-level warning().
            diffusers_logging.get_logger(__name__).warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        # Parallel topology
        sp_world_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        txt_pad_len = 0

        # ---------------- input embedding ----------------
        hidden_states = self.img_in(hidden_states)  # image embedding
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)  # text normalisation
        encoder_hidden_states = self.txt_in(encoder_hidden_states)  # text embedding

        # Timestep embedding (temb)
        timestep = timestep.to(hidden_states.dtype)
        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        # Full image sequence length before sharding (the original author's example: 6889).
        img_seq_len = hidden_states.shape[1]
        # Number of positions missing for the length to be divisible by sp_world_size.
        img_pad_len = (sp_world_size - (img_seq_len % sp_world_size)) % sp_world_size
        if sp_world_size > 1:
            # NOTE(review): assumes torch.chunk yields exactly sp_world_size
            # chunks for the sequence lengths in use - TODO confirm for short sequences.
            hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]

        # RoPE frequencies are generated for the full sequence and are not
        # sharded here; they are passed to the blocks as a pair.
        img_freqs, txt_freqs = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
        image_rotary_emb = (img_freqs, txt_freqs)

        # Loop invariant: how many transformer blocks share one controlnet sample.
        interval_control = None
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
            interval_control = int(np.ceil(interval_control))

        # ---------------- transformer blocks ----------------
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # NOTE(review): this path unpacks the result in the opposite
                # order to the path below and does not forward
                # img_pad_len/txt_pad_len - TODO confirm it is intended.
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                )
            else:
                block_kwargs = dict(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=attention_kwargs,
                    img_pad_len=img_pad_len,
                    txt_pad_len=txt_pad_len,
                )
                cache_agent = None
                if use_cache:
                    # Either agent may be None or absent when only one of the
                    # two caches is enabled; fall back to a direct call then.
                    cache_agent = getattr(self, "cache_cond" if if_cond else "cache_uncond", None)
                if cache_agent is None:
                    hidden_states, encoder_hidden_states = block(**block_kwargs)
                else:
                    hidden_states, encoder_hidden_states = cache_agent.apply(block, **block_kwargs)

            # controlnet residual
            if controlnet_block_samples is not None:
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        # ---------------- output (image stream only) ----------------
        hidden_states = self.norm_out(hidden_states, temb)  # AdaLayerNorm modulation
        output = self.proj_out(hidden_states)  # output projection

        # Pad the last rank's short shard before gathering; every rank
        # computes the unpadded total length.
        if img_pad_len > 0:
            output_original_len = output.shape[1] * sp_world_size - img_pad_len
            if sp_rank == sp_world_size - 1:
                output_original_len = (output.shape[1] + img_pad_len) * sp_world_size - img_pad_len
                output = torch.nn.functional.pad(
                    output,
                    (0, 0, 0, img_pad_len, 0, 0),  # pad only the end of the sequence dim (dim=1)
                )

        # Concatenate the image shards along dim 1.
        output = output if sp_world_size <= 1 else get_sp_group().all_gather(output, dim=1)

        # Strip the padding again after the gather.
        if img_pad_len > 0:
            output = output[:, :output_original_len, :].contiguous()

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    # Bind as a method of this transformer instance and replace forward.
    transformer.forward = parallel_forward.__get__(transformer)

    # Swap every block's attention processor for the parallel version.
    for block in transformer.transformer_blocks:
        block.attn.processor = xFuserQwenDoubleStreamAttnProcessor()
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
# -------------------------- Flask Web 服务初始化 --------------------------
app = Flask(__name__)
# Thread-safe queues linking the Flask worker threads with the main inference
# loop: requests go out on one, results (or exceptions) come back on the other.
request_queue = queue.Queue()
response_queue = queue.Queue()
# Flask runs with threaded=True, but both queues are shared by all requests.
# Holding this lock from "enqueue request" until "result received" keeps a
# client from picking up the result that belongs to another client.
_generate_lock = threading.Lock()


@app.route('/generate', methods=['POST'])
def generate_api():
    """Handle POST /generate: validate the JSON body, queue it and return the PNG.

    Expected JSON fields: ``prompt`` (required), ``negative_prompt``
    (default ""), ``width`` and ``height`` (default 1024, at most 2512,
    rounded down to a multiple of 16).

    Returns:
        The generated image as ``image/png`` on success, a JSON error with
        status 400 for invalid input, or a JSON error with status 500 when
        inference fails.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so bad input is reported as 400 below.
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get("prompt", "")
    negative_prompt = data.get("negative_prompt", "")
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    try:
        width = int(data.get("width", 1024))
        height = int(data.get("height", 1024))
    except (TypeError, ValueError):
        return jsonify({"error": "Width and height must be integers"}), 400

    # Upper bound check
    if width > 2512 or height > 2512:
        return jsonify({"error": "Width and height must be <= 2512"}), 400

    # Keep both sides a multiple of 16 to avoid Transformer/VAE errors.
    width = (width // 16) * 16
    height = (height // 16) * 16
    # Anything below 16 has just been rounded down to 0 (or is negative).
    if width < 16 or height < 16:
        return jsonify({"error": "Width and height must be >= 16"}), 400

    try:
        with _generate_lock:
            # Hand the request to the main inference loop on rank 0 ...
            request_queue.put({
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "width": width,
                "height": height
            })
            # ... and block until its result (a BytesIO or an Exception) arrives.
            result = response_queue.get()
        if isinstance(result, Exception):
            return jsonify({"error": str(result)}), 500
        # Stream the raw PNG bytes back (no base64), so browsers and curl can
        # display or save the image directly.
        return send_file(result, mimetype='image/png')
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def run_flask_app(host=None, port=None):
    """Start the Flask development server (blocking call).

    Args:
        host: Address to bind. When None, taken from the
            ``QWEN_IMAGE_SERVER_HOST`` environment variable, falling back to
            "10.119.1.59".
        port: Port to listen on. When None, taken from the
            ``QWEN_IMAGE_SERVER_PORT`` environment variable, falling back to
            8080.

    Calling it without arguments and without those environment variables
    binds exactly the same address and port as before.
    """
    if host is None:
        host = os.environ.get("QWEN_IMAGE_SERVER_HOST", "10.119.1.59")
    if port is None:
        port = int(os.environ.get("QWEN_IMAGE_SERVER_PORT", "8080"))
    # Service IP and port
    app.run(host=host, port=port, threaded=True, use_reloader=False)
def main():
    """Entry point of the online inference service.

    Steps, in order: parse CLI arguments, set up the (optionally distributed)
    NPU environment, load the transformer and the pipeline, optionally
    quantize / enable block caching / enable CFG+USP parallelism, start the
    Flask server on rank 0, then loop forever: rank 0 takes a request from
    ``request_queue``, it is broadcast to every rank, all ranks run the
    pipeline, and rank 0 puts the PNG bytes (or the exception) on
    ``response_queue``.

    Raises:
        FileNotFoundError: ``--quant_desc_path`` does not exist.
        ValueError: invalid quantization file name or parallel configuration.
    """
    # Enable the NPU private (internal) format; per the original author this
    # silences "RuntimeError: invalid npu option name: allow_internal_format".
    if hasattr(torch_npu.npu, 'config'):
        torch_npu.npu.config.allow_internal_format = True
    else:
        # Fallback: the lower-level switch.
        torch.npu.set_compile_mode("jit_compile=True")

    # ---------------- argument parsing ----------------
    parser = argparse.ArgumentParser(description="Qwen-Image CFG+USP Parallel Inference Service")
    parser.add_argument("--prompt_lang", type=str, default="en", choices=["en", "zh"],
                        help="Language for positive magic prompt. 'en' for English, 'zh' for Chinese. ")
    # Distributed / model parallel options
    parser.add_argument("--batch-size", type=int, default=1, help="批次大小")
    parser.add_argument("--cfg_size", type=int, default=1, help="CFG并行大小")
    parser.add_argument("--ulysses_size", type=int, default=1, help="Ulysses并行大小")
    parser.add_argument("--ring_size", type=int, default=1, help="Ring注意力并行大小")
    parser.add_argument("--tp_size", type=int, default=1, help="张量并行大小")
    parser.add_argument("--vae_parallel", action="store_true", default=False, help="VAE并行")
    parser.add_argument("--t5_fsdp", action="store_true", default=False, help="T5使用FSDP")
    parser.add_argument("--t5_cpu", action="store_true", default=False, help="T5放CPU")
    parser.add_argument("--dit_fsdp", action="store_true", default=False, help="DiT使用FSDP")
    # Model and device options
    parser.add_argument("--model_path", type=str, default="/home/weight/Qwen-Image-Edit-2509/", help="模型本地路径")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"], help="数据类型")
    parser.add_argument("--device", type=str, default="npu", help="运行设备(npu/cuda/cpu)")
    parser.add_argument("--device_id", type=int, default=0, help="设备ID")
    # Inference options
    parser.add_argument("--num_inference_steps", type=int, default=50, help="推理步数")
    parser.add_argument("--true_cfg_scale", type=float, default=4.0, help="真实CFG缩放系数")
    parser.add_argument("--guidance_scale", type=float, default=1.0, help="引导缩放系数")
    parser.add_argument("--seed", type=int, default=0, help="随机种子")
    parser.add_argument(
        "--quant_desc_path",
        type=str,
        default=None,
        help="Path to quantization description file (e.g., quant_model_description_*.json). "
             "Enables quantization if provided (applies to Text Encoder and Transformer)."
    )
    args = parser.parse_args()

    if args.quant_desc_path:
        if not os.path.exists(args.quant_desc_path):
            raise FileNotFoundError(f"Quantization description file not found: {args.quant_desc_path}")
        if not args.quant_desc_path.endswith(".json") or "quant_model_description" not in args.quant_desc_path:
            raise ValueError(f"Invalid quantization file: {args.quant_desc_path}. "
                             "Expected format: 'quant_model_description_*.json'")

    # ---------------- distributed initialisation ----------------
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    _init_logging(rank)
    if world_size > 1:
        device = f"{args.device}:{local_rank}"
        torch.npu.set_device(local_rank)
        dist.init_process_group(
            backend="hccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        # Explicit raises instead of assert: assertions are stripped under -O.
        if args.t5_fsdp or args.dit_fsdp:
            raise ValueError("t5_fsdp and dit_fsdp are not supported in non-distributed environments.")
        if args.cfg_size > 1 or args.ulysses_size > 1 or args.ring_size > 1:
            raise ValueError("context parallel are not supported in non-distributed environments.")
        if args.vae_parallel:
            raise ValueError("vae parallel are not supported in non-distributed environments.")
        device = f"{args.device}:{args.device_id}"
        torch.npu.set_device(args.device_id)

    # Suffix appended to every prompt, selected by --prompt_lang.
    positive_magic = {
        "en": "Ultra HD, 4K, cinematic composition.",
        "zh": "超清,4K,电影级构图"
    }

    torch_dtype = torch.bfloat16 if args.torch_dtype == "bfloat16" else torch.float32

    # ---------------- model loading ----------------
    if rank == 0:
        print(f"从 {args.model_path} 加载模型...")
    transformer = QwenImageTransformer2DModel.from_pretrained(
        os.path.join(args.model_path, 'transformer'),
        torch_dtype=torch_dtype,
        device_map=None,  # no automatic device mapping
        low_cpu_mem_usage=True,  # low CPU memory mode while loading
    )
    if args.quant_desc_path:
        from mindiesd import quantize
        if rank == 0:
            print("Quantizing Transformer on NPU (单独量化核心组件)...")
        quantize(
            model=transformer,
            quant_des_path=args.quant_desc_path,
            use_nz=True,
        )
    torch.npu.empty_cache()  # release cached NPU memory
    gc.collect()  # trigger Python garbage collection

    pipeline = QwenImagePipeline.from_pretrained(
        args.model_path,
        transformer=transformer,
        torch_dtype=torch_dtype,
        device_map=None,  # no automatic device mapping
        low_cpu_mem_usage=True,  # low CPU memory mode
    )
    # VAE slicing/tiling to avoid running out of device memory.
    pipeline.vae.use_slicing = True
    pipeline.vae.use_tiling = True
    pipeline.to(device)
    pipeline.set_progress_bar_config(disable=None)  # show the progress bar

    # ---------------- optional block cache ----------------
    if COND_CACHE or UNCOND_CACHE:
        # Conservative cache settings
        cache_config = CacheConfig(
            method="dit_block_cache",
            blocks_count=60,
            steps_count=args.num_inference_steps,
            step_start=10,
            step_interval=3,
            step_end=35,
            block_start=10,
            block_end=50
        )
        pipeline.transformer.cache_cond = CacheAgent(cache_config) if COND_CACHE else None
        pipeline.transformer.cache_uncond = CacheAgent(cache_config) if UNCOND_CACHE else None
        if rank == 0:
            print("启用缓存配置")

    # ---------------- optional CFG / sequence parallelism ----------------
    if args.cfg_size > 1 or args.ulysses_size > 1 or args.ring_size > 1 or args.tp_size > 1:
        if args.cfg_size * args.ulysses_size * args.ring_size * args.tp_size != world_size:
            raise ValueError(
                "The number of cfg_size, ulysses_size and ring_size should be equal to the world size."
            )
        sp_degree = args.ulysses_size * args.ring_size
        parallel_config = ParallelConfig(
            sp_degree=sp_degree,
            ulysses_degree=args.ulysses_size,
            ring_degree=args.ring_size,
            tp_degree=args.tp_size,
            use_cfg_parallel=(args.cfg_size == 2),
            world_size=world_size,
        )
        init_parallel_env(parallel_config)
    if args.ulysses_size > 1:
        parallelize_qwen_image_transformer(pipeline)

    # ---------------- web server (rank 0 only) ----------------
    if rank == 0:
        logging.info("Starting Flask Online API Server on 10.119.1.59:8080...")
        flask_thread = threading.Thread(target=run_flask_app, daemon=True)
        flask_thread.start()

    # Collectives are only valid once a process group exists; in the
    # single-process path above none is created, so they must be skipped there.
    distributed = dist.is_available() and dist.is_initialized()

    # Wait until every rank is ready.
    if distributed:
        dist.barrier()

    # ---------------- resident inference loop (all ranks in lock-step) ----------------
    while True:
        # Container for the broadcast payload.
        sync_data = [None]
        # Only rank 0 reads from the request queue.
        if rank == 0:
            try:
                # Blocks until an API request arrives.
                sync_data = [request_queue.get()]
            except Exception:
                # Keep serving, but leave a trace instead of failing silently.
                logging.exception("Failed to read a request from the queue")
        # Broadcast so that every participating rank runs with identical inputs.
        if distributed:
            dist.broadcast_object_list(sync_data, src=0)

        request_params = sync_data[0]
        if not request_params:
            continue

        lang = args.prompt_lang
        raw_prompt = request_params["prompt"]
        full_prompt = raw_prompt + " " + positive_magic.get(lang, "")
        negative_prompt = request_params["negative_prompt"]
        width = request_params["width"]
        height = request_params["height"]

        # Start each request from a clean memory state.
        torch.npu.empty_cache()
        gc.collect()
        if rank == 0:
            logging.info(f"Start Parallel Inference... Size: {width}x{height}, Prompt: {raw_prompt}")
        start_time = time.time()

        image = None
        try:
            # Joint inference across all ranks of the parallel environment.
            image = pipeline(
                prompt=full_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=args.num_inference_steps,
                true_cfg_scale=args.true_cfg_scale,
                generator=torch.Generator(device=device).manual_seed(args.seed),
            ).images[0]
            torch.npu.synchronize()
            # Only rank 0 talks to the web thread, so only it returns the result.
            if rank == 0:
                end_time = time.time()
                logging.info(f"Inference Completed. Time: {end_time - start_time:.2f} seconds")
                # Encode the image as PNG into an in-memory byte stream.
                img_byte_arr = io.BytesIO()
                image.save(img_byte_arr, format='PNG')
                img_byte_arr.seek(0)
                # Hand it to the Flask thread, which returns it to the client.
                response_queue.put(img_byte_arr)
        except Exception as e:
            if rank == 0:
                logging.error(f"Inference error: {str(e)}")
                response_queue.put(e)
        finally:
            # Drop every per-request reference before the next iteration.
            del image
            del request_params, full_prompt, negative_prompt
            torch.npu.empty_cache()
            gc.collect()
if __name__ == "__main__":
    main()

运行 run.sh 脚本,需检查权重路径是否正确,然后启动服务(注意此处的 ASCEND_RT_VISIBLE_DEVICES 为需使用的 NPU 卡号,如下配置使用 8 张卡)
#!/bin/bash
export model_path="/weight/Qwen-Image-2512"
# Memory management: cap the maximum split size to avoid unusable memory fragments
export PYTORCH_NPU_ALLOC_CONF=max_split_size_mb:512
export HCCL_EXEC_TIMEOUT=0 # 0 means wait indefinitely instead of the 1800s limit
export LCCL_DETERMINISTIC=true
export HCCL_DETERMINISTIC=true
export ATB_MATMUL_SHUFFLE_K_ENABLE=0
export ATB_LLM_LCOC_ENABLE=true
export CLOSE_MATMUL_K_SHIFT=true
# Operator optimisations
export ROPE_FUSE=1
export ADALN_FUSE=1
# Algorithm optimisations
export COND_CACHE=1
export UNCOND_CACHE=1
# 8 cards: cfg_size=1 x ulysses_size=8 (the product of the parallel sizes must equal --nproc_per_node)
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
torchrun --nproc_per_node=8 --master-port 29508 run_cfg_usp-7.py \
--model_path "${model_path}" \
--prompt_lang "en" \
--num_inference_steps 50 \
--seed 42 \
--ulysses_size 8 \
--cfg_size 1

8卡并行推理服务启动后会看到类似如下的打印:
[root@ww-qwen-image-2512-a3-mindie-worker-0 Qwen-Image]# ./run_cfg_usp-7.sh
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING]
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING] *****************************************
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING] *****************************************
从 /weight/Qwen-Image-2512 加载模型...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 30.76it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.93it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.06it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 27.54it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.72it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 28.25it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 27.62it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 26.13it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.63it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.37it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 20.98it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 21.57it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7.28it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 21.38it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7.47it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7.74it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7.31it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.88it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 20.90it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.40it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.88it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.49it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.78it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.95it/s]
启用缓存配置
[2026-03-05 14:47:21,439] WARNING: Model parallel is not initialized, initializing...
[2026-03-05 14:47:21,444] INFO: Starting Flask Online API Server on 10.119.1.59:8080...
* Serving Flask app 'run_cfg_usp-7'
* Debug mode: off
[2026-03-05 14:47:21,448] INFO: WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.
* Running on http://10.119.1.59:8080
[2026-03-05 14:47:21,448] INFO: Press CTRL+C to quit
安装图片生成测试工具
yum install curl jq
使用 curl 命令进行文生图测试: 可参考如下命令,输入 prompt,输出 output_image.png 图片。
curl -X POST http://10.119.1.59:8080/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "A futuristic city under a neon sky.",
"width": 1024,
"height": 1024
}' \
-o output_image.png