Ascend-SACT/Qwen-Image-2512_mindIE

A2 mindIE + Qwen-Image-2512: 8-card online inference

I. Environment Configuration

Item                  Description
Hardware              Atlas A2 910B
CANN version          8.3rc1
Inference framework   mindIE 2.3
Deployment            8 cards

II. Downloading the Weights

Install the download tool:
pip install modelscope
Download the model weights:
modelscope download --model Qwen/Qwen-Image-2512 --local_dir /weight/Qwen-Image-2512
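
After the download finishes, it can help to confirm that the weight directory has the layout the service script expects. The following is a minimal sketch, not part of the original repository: `missing_subdirs` is a hypothetical helper, and by default it checks only the `transformer` subfolder, because that is the one the service script later in this document loads explicitly. Other subfolders the pipeline may need are not listed in this document, so they are not checked.

```python
import os

def missing_subdirs(model_path, required=("transformer",)):
    """Return the required subfolders that are absent under model_path."""
    return [name for name in required
            if not os.path.isdir(os.path.join(model_path, name))]

if __name__ == "__main__":
    # /weight/Qwen-Image-2512 is the --local_dir used in the download command
    missing = missing_subdirs("/weight/Qwen-Image-2512")
    print("missing subfolders:", missing or "none")
```

An empty result means every listed subfolder exists; it does not verify that the weight files themselves are complete.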

III. Container Image

mindIE image (download the official mindIE image): mindie_2.3.0-py311-openeuler24.03:A2.0.2

IV. Python Code and Dependency Installation

1. Install the dependencies
  pip install diffusers==0.35.1
  pip install transformers==4.52.4
  pip install yunchang==0.6.0
2. Download the code
git clone https://modelers.cn/MindIE/Qwen-Image.git && cd Qwen-Image
3. Configure the environment
  1. Use Python to locate the diffusers installation directory
DIFFUSERS_PATH=$(python -c "import diffusers; import os; print(os.path.dirname(diffusers.__file__))")
  2. Replace the pipeline_qwenimage file (text-to-image scenario)
  cp -r pipeline_qwenimage.py "$DIFFUSERS_PATH/pipelines/qwenimage/pipeline_qwenimage.py"
  3. Replace the transformer_qwenimage file
cp -r transformer_qwenimage.py "$DIFFUSERS_PATH/models/transformers/transformer_qwenimage.py"
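
Because the replaced files target a specific diffusers release, it may be worth checking that the installed versions match the pins before copying anything. The sketch below is one way to do that with the standard library; `PINNED`, `installed_versions`, and `version_mismatches` are hypothetical names introduced here, and the version numbers are the ones from the install commands.

```python
from importlib import metadata

# Pins taken from the pip install commands in this section
PINNED = {"diffusers": "0.35.1", "transformers": "4.52.4", "yunchang": "0.6.0"}

def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found

def version_mismatches(pinned, installed):
    """Return {name: (expected, actual)} for every package that differs."""
    return {name: (want, installed.get(name))
            for name, want in pinned.items()
            if installed.get(name) != want}

if __name__ == "__main__":
    mismatches = version_mismatches(PINNED, installed_versions(PINNED))
    for name, (want, got) in mismatches.items():
        print(f"{name}: expected {want}, found {got}")
```

Running the script prints one line per mismatched or missing package and nothing when all three pins are satisfied.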

V. Service Deployment

Deploy the online inference service on 8 cards of a 910B.

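To enable the online inference service, the environment configuration is the same as in Section 3, and the run.py file invoked by run.sh needs some adaptation.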

1) Install the Flask framework
pip install flask
2) Build the Flask service script

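Replace the offline inference code in the original run_cfg_usp.py with the Flask-based online inference service below, which avoids reloading the weights for every inference. Back up the original run_cfg_usp.py first. Note: the service IP and port are set in the run_flask_app() function, at line 511 of the code; change them to the IP and port of your environment.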

import sys
import os
import json
import argparse
import functools
import time
import torch
import torch_npu
import threading
import queue
import io
from flask import Flask, request, send_file, jsonify

#------------- Work around the diffusers 0.35.1 / torch 2.1 error -------------
def custom_op(
    name,
    fn=None,
    /,
    *,
    mutates_args,
    device_types=None,
    schema=None,
    tags=None,
):
    def decorator(func):
        return func
    
    if fn is not None:
        return decorator(fn)
    
    return decorator

def register_fake(
    op,
    fn=None,
    /,
    *,
    lib=None,
    _stacklevel: int = 1,
    allow_override: bool = False,
):
    def decorator(func):
        return func
    
    if fn is not None:
        return decorator(fn)
    
    return decorator
    
torch.library.custom_op = custom_op
torch.library.register_fake = register_fake
#-----------------------------------------------------------------------

import torch.distributed as dist
from typing import Optional, Tuple, Union, List, Dict, Any
import numpy as np
import gc  # added: garbage collection module

from PIL import Image
from mindiesd import CacheConfig, CacheAgent
from mindiesd import attention_forward

# Cache switches (read from environment variables)
COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0)))
UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0)))

from diffusers import DiffusionPipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from qwenimage.pipeline_qwenimage import QwenImagePipeline

import logging

#torch.npu.set_per_process_memory_fraction(0.9)

class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names."""

    def parse_args(self, args=None, namespace=None):
        if args is None:
            args = sys.argv[1:]

        # Convert underscores to dashes and vice versa in argument names
        processed_args = []
        for arg in args:
            if arg.startswith("--"):
                if "=" in arg:
                    key, value = arg.split("=", 1)
                    key = "--" + key[len("--") :].replace("-", "_")
                    processed_args.append(f"{key}={value}")
                else:
                    processed_args.append("--" + arg[len("--") :].replace("-", "_"))
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)

from qwenimage.distributed.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env

from qwenimage.distributed.parallel_mgr import (
    get_sequence_parallel_world_size,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
    init_distributed_environment,
    initialize_model_parallel,
    get_sequence_parallel_rank,
    get_sp_group
)

from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttentionProcessor
from qwenimage.transformer_qwenimage import (  
    QwenDoubleStreamAttnProcessor2_0,
    QwenImageTransformer2DModel,
    QwenEmbedRope,
    apply_rotary_emb_qwen,
    AdaLayerNorm,
)

from qwenimage.distributed.all_to_all import all_to_all_4D, SeqAllToAll4D

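# 1. Parallel attention processor (inherits from the native Qwen processor)
class xFuserQwenDoubleStreamAttnProcessor(QwenDoubleStreamAttnProcessor2_0):
    """
    Inherits the native Qwen double-stream attention processor and adds USP parallel support.
    Keeps all native logic; parallelism is introduced only in the attention computation.
    """
    def __init__(self):
        super().__init__()  # call the native Qwen processor initializer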

        self.ulysses_pg = get_sp_group().ulysses_group
        self.scatter_idx = 2
        self.gather_idx = 1

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # image stream (B, S_img/P, D)
        encoder_hidden_states: torch.FloatTensor,  # text stream (B, S_txt, D)
        encoder_hidden_states_mask: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (img_freqs, txt_freqs)
        img_pad_len = None,
        txt_pad_len = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:

        txt_seq_len = encoder_hidden_states.shape[1]   
        
        sp_world_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        img_original_len = hidden_states.shape[1]  * sp_world_size
        # Pad before the first communication
        if img_pad_len > 0:
            img_original_len = hidden_states.shape[1]  * sp_world_size - img_pad_len
            if sp_rank == sp_world_size - 1:
            
                img_original_len = (hidden_states.shape[1] + img_pad_len) * sp_world_size - img_pad_len
                hidden_states = torch.nn.functional.pad(
                    hidden_states, 
                    (0, 0, 0, img_pad_len, 0, 0)  # pad pairs run from the last dim backwards; only the end of seq_len (dim=1) is padded
                )
 
        # -------------------------- Keep the native Qwen QKV computation --------------------------
        # Image-stream QKV
        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1))
        query_sycn = SeqAllToAll4D.apply(
            self.ulysses_pg, img_query, self.scatter_idx, self.gather_idx
            # [B, S_image/ulysses_size, H, D] -->   [B, S_image, H/ulysses_size, D]
        )
        img_key = attn.to_k(hidden_states).unflatten(-1, (attn.heads, -1))
        key_sycn = SeqAllToAll4D.apply(
            self.ulysses_pg, img_key, self.scatter_idx, self.gather_idx
        )
        img_value = attn.to_v(hidden_states).unflatten(-1, (attn.heads, -1))
        value_sycn = SeqAllToAll4D.apply(
            self.ulysses_pg, img_value, self.scatter_idx, self.gather_idx
        )
        # Text-stream QKV
        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))  # (B, S_txt, H, D_head)
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        img_query = query_sycn()
        img_key = key_sycn()
        # Strip the padding after the first communication
        if img_pad_len > 0:
            img_query = img_query[:, :img_original_len, :, :].contiguous()
            img_key = img_key[:, :img_original_len, :, :].contiguous()

        # QK normalization
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)        
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        img_value = value_sycn()

        # Strip the padding after the first communication
        if img_pad_len > 0:
            img_value = img_value[:, :img_original_len, :, :].contiguous()

        sp_world_size = get_sequence_parallel_world_size()  # USP
        sp_rank = get_sequence_parallel_rank()

        txt_query = torch.chunk(txt_query, sp_world_size, dim=2)[sp_rank]  # [B, S_text, H, D] -->  [B, S_text, H/ulysses_size, D]
        txt_key = torch.chunk(txt_key, sp_world_size, dim=2)[sp_rank]
        txt_value = torch.chunk(txt_value, sp_world_size, dim=2)[sp_rank]

        joint_query = torch.cat([txt_query, img_query], dim=1)  # (B, S_txt  + S_img, H/ulysses_size, D_head)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)
        out = attention_forward(
            joint_query,
            joint_key,
            joint_value,
            opt_mode="manual",
            op_type="fused_attn_score",
            layout="BNSD"
        )

        if type(out) == tuple:
            context_layer, _, _ = out
        else:
            context_layer = out

        txt_seq_len = txt_query.shape[1]

        text_out = context_layer[:, :txt_seq_len, :, :].contiguous()  # force contiguous memory
        image_out = context_layer[:, txt_seq_len:, :, :].contiguous()

        # Pad before the second communication
        if img_pad_len > 0:
            image_out = torch.nn.functional.pad(
                image_out, 
                (0, 0, 0, 0, 0, img_pad_len, 0, 0)  # pad pairs run from the last dim backwards; only the end of seq_len (dim=1) is padded
            )

        img_attn_output = SeqAllToAll4D.apply(
            self.ulysses_pg, image_out, self.gather_idx, self.scatter_idx, True
            # [B,  S_image, H/ulysses_size, D] -->  [B, S_image/ulysses_size, H, D]
        ).flatten(2, 3).to(img_query.dtype)

        # Strip the padding after the second communication
        if img_pad_len > 0:
            if sp_rank == sp_world_size - 1:
                img_attn_output_len = img_attn_output.shape[1]
                img_attn_output = img_attn_output[:, :img_attn_output_len-img_pad_len, :].contiguous()

        txt_attn_asyn = get_sp_group().all_gather(text_out, dim=2, async_op=True)  # (B, S_txt  , H/ulysses_size, D_head) -->   (B, S_txt  , H, D_head)

        # Output projection (strictly keeps Qwen's projection logic)
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # apply dropout
        txt_attn_output = txt_attn_asyn().flatten(2, 3).to(img_query.dtype)
        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output

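# 2. Transformer parallelization function (adapted to the Qwen-Image structure)
def parallelize_qwen_image_transformer(pipe: DiffusionPipeline):
    """
    Parallelizes the Qwen-Image Transformer, adding CFG and USP support.
    Strictly keeps Qwen's double-stream processing logic.
    """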
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def parallel_forward(
        self,
        hidden_states: torch.Tensor,  # image-stream input (B, S_img, D_in)
        encoder_hidden_states: torch.Tensor,  # text-stream input (B, S_txt, D_joint)
        encoder_hidden_states_mask: Optional[torch.Tensor] = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,  # Qwen-specific: (frame, H, W)
        txt_seq_lens: Optional[List[int]] = None,  # Qwen-specific: list of text lengths
        guidance: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        return_dict: bool = True,
        use_cache: bool = False,     
        if_cond: bool = True,    
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logging.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # Initialize parameters
        cfg_world_size = get_classifier_free_guidance_world_size()
        cfg_rank = get_classifier_free_guidance_rank()
        sp_world_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()

        txt_pad_len = 0

        # -------------------------- Keep the native Qwen input processing --------------------------
        hidden_states = self.img_in(hidden_states)  # image embedding
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)  # text normalization
        encoder_hidden_states = self.txt_in(encoder_hidden_states)  # text embedding

        # Timestep embedding (temb)
        timestep = timestep.to(hidden_states.dtype)
        temb = self.time_text_embed(timestep, hidden_states) if guidance is None else self.time_text_embed(timestep, guidance, hidden_states)

        # Original image sequence length
        img_seq_len = hidden_states.shape[1]   # 6889
        # Padding length needed to make it divisible by sp_world_size
        img_pad_len = (sp_world_size - (img_seq_len % sp_world_size)) % sp_world_size

        if sp_world_size > 1:
            hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]

        #  RoPE parameter generation and splitting
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
        img_freqs, txt_freqs = image_rotary_emb
        # Split the RoPE parameters in step with the sequence split (currently disabled)
        # if sp_world_size > 1:
        #    img_freqs = torch.chunk(img_freqs, sp_world_size, dim=0)[sp_rank]  # split the image RoPE
        image_rotary_emb = (img_freqs, txt_freqs)

        # -------------------------- Transformer block forward pass --------------------------
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                )

            else:
                if not use_cache:  
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_hidden_states_mask=encoder_hidden_states_mask,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=attention_kwargs,
                        img_pad_len=img_pad_len,
                        txt_pad_len=txt_pad_len
                    )
                else:
                    if if_cond:
                        hidden_states, encoder_hidden_states = self.cache_cond.apply(
                            block,
                            hidden_states=hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            encoder_hidden_states_mask=encoder_hidden_states_mask,
                            temb=temb,
                            image_rotary_emb=image_rotary_emb,
                            joint_attention_kwargs=attention_kwargs,
                            img_pad_len=img_pad_len,
                            txt_pad_len=txt_pad_len
                        )
                    else:
                        hidden_states, encoder_hidden_states = self.cache_uncond.apply(
                            block,
                            hidden_states=hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            encoder_hidden_states_mask=encoder_hidden_states_mask,
                            temb=temb,
                            image_rotary_emb=image_rotary_emb,
                            joint_attention_kwargs=attention_kwargs,
                            img_pad_len=img_pad_len,
                            txt_pad_len=txt_pad_len
                        )
                #-----------------------------------

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
       
        # Use only the image part (hidden_states) from the dual-stream blocks
        # -------------------------- Output processing (aligned with native Qwen) --------------------------
        hidden_states = self.norm_out(hidden_states, temb)  # AdaLayerNorm modulation
        output = self.proj_out(hidden_states)  # output projection

        # Pad before the first communication
        if img_pad_len > 0:  
            output_original_len = output.shape[1]  * sp_world_size - img_pad_len
            if sp_rank == sp_world_size - 1:

                output_original_len = (output.shape[1] + img_pad_len) * sp_world_size - img_pad_len

                output = torch.nn.functional.pad(
                    output, 
                    (0, 0, 0, img_pad_len, 0, 0)  # pad pairs run from the last dim backwards; only the end of seq_len (dim=1) is padded
                )
        # Concatenate the image sequence along dim 1
        # hidden_states = self._cat_sequence(hidden_states, dim=1)
        output = output if sp_world_size <= 1 else get_sp_group().all_gather(output, dim=1)

        # Strip the padding after the first communication
        if img_pad_len > 0:
            output = output[:, :output_original_len, :].contiguous()

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        # -------------------------- Return the result --------------------------
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    parallel_forward = parallel_forward.__get__(transformer)
    # Replace the forward method
    transformer.forward = parallel_forward

    #-------------------------- Replace the attention processors with the parallel version --------------------------
    for block in transformer.transformer_blocks:
        block.attn.processor = xFuserQwenDoubleStreamAttnProcessor()

def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)

# -------------------------- Flask web service initialization --------------------------
app = Flask(__name__)
# Thread-safe queues for communication between the Flask web service and the main inference thread
request_queue = queue.Queue()
response_queue = queue.Queue()

@app.route('/generate', methods=['POST'])
def generate_api():
    """Receive online requests."""
    try:
        data = request.json or {}
        prompt = data.get("prompt", "")
        negative_prompt = data.get("negative_prompt", "")
        width = int(data.get("width", 1024))
        height = int(data.get("height", 1024))

        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400

        # Maximum-size check
        if width > 2512 or height > 2512:
            return jsonify({"error": "Width and height must be <= 2512"}), 400

        # Make sure width and height are multiples of 16 to avoid Transformer/VAE errors
        width = (width // 16) * 16
        height = (height // 16) * 16

        # Push the request onto the queue to notify the main inference thread on rank 0
        request_queue.put({
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height
        })

        # Block until the inference result (a BytesIO object) arrives
        result = response_queue.get()

        if isinstance(result, Exception):
            return jsonify({"error": str(result)}), 500

        # Return the raw image stream without base64 encoding; browsers and curl can display or save it directly
        return send_file(result, mimetype='image/png')
    
    except Exception as e:
        return jsonify({"error": str(e)}), 500

def run_flask_app():
    # Service IP and port
    app.run(host="10.119.1.59", port=8080, threaded=True, use_reloader=False)

def main():
    # Enable the NPU private format to suppress the "RuntimeError: invalid npu option name: allow_internal_format" warning during inference
    if hasattr(torch_npu.npu, 'config'):
        torch_npu.npu.config.allow_internal_format = True
    else:
        # If the above does not work, fall back to this lower-level switch
        torch.npu.set_compile_mode("jit_compile=True")

    # Parse arguments
    parser = argparse.ArgumentParser(description="Qwen-Image CFG+USP Parallel Inference Service")

    parser.add_argument("--prompt_lang", type=str, default="en", choices=["en", "zh"],
                    help="Language for positive magic prompt. 'en' for English, 'zh' for Chinese. ")

    # Distributed and model configuration arguments (unchanged)
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument("--cfg_size", type=int, default=1, help="CFG parallel size")
    parser.add_argument("--ulysses_size", type=int, default=1, help="Ulysses parallel size")
    parser.add_argument("--ring_size", type=int, default=1, help="Ring attention parallel size")
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--vae_parallel", action="store_true", default=False, help="VAE parallelism")
    parser.add_argument("--t5_fsdp", action="store_true", default=False, help="Use FSDP for T5")
    parser.add_argument("--t5_cpu", action="store_true", default=False, help="Keep T5 on the CPU")
    parser.add_argument("--dit_fsdp", action="store_true", default=False, help="Use FSDP for DiT")

    # Model and device configuration (unchanged)
    parser.add_argument("--model_path", type=str, default="/home/weight/Qwen-Image-Edit-2509/", help="Local model path")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"], help="Data type")
    parser.add_argument("--device", type=str, default="npu", help="Device to run on (npu/cuda/cpu)")
    parser.add_argument("--device_id", type=int, default=0, help="Device ID")

    # Inference and output configuration
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--true_cfg_scale", type=float, default=4.0, help="True CFG scale")
    parser.add_argument("--guidance_scale", type=float, default=1.0, help="Guidance scale")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")

    parser.add_argument(
        "--quant_desc_path",
        type=str,
        default=None,
        help="Path to quantization description file (e.g., quant_model_description_*.json). "
             "Enables quantization if provided (applies to Text Encoder and Transformer)."
    )

    args = parser.parse_args()

    if args.quant_desc_path:
        if not os.path.exists(args.quant_desc_path):
            raise FileNotFoundError(f"Quantization description file not found: {args.quant_desc_path}")
        if not args.quant_desc_path.endswith(".json") or "quant_model_description" not in args.quant_desc_path:
            raise ValueError(f"Invalid quantization file: {args.quant_desc_path}. "
                                "Expected format: 'quant_model_description_*.json'")

    # Distributed initialization
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    
    _init_logging(rank)

    if world_size > 1:
        device = f"{args.device}:{local_rank}"
        torch.npu.set_device(local_rank)

        dist.init_process_group(
            backend="hccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.cfg_size > 1 or args.ulysses_size > 1 or args.ring_size > 1
        ), f"context parallel are not supported in non-distributed environments."
        assert not (
            args.vae_parallel
        ), f"vae parallel are not supported in non-distributed environments."

        device = f"{args.device}:{args.device_id}"
        torch.npu.set_device(args.device_id)

    positive_magic = {
        "en": "Ultra HD, 4K, cinematic composition.",
        "zh": "超清,4K,电影级构图"
    }

    # Data type configuration
    torch_dtype = torch.bfloat16 if args.torch_dtype == "bfloat16" else torch.float32
    
    # Load the model
    if rank == 0:
        print(f"从 {args.model_path} 加载模型...")

    transformer = QwenImageTransformer2DModel.from_pretrained(
        os.path.join(args.model_path, 'transformer'),
        torch_dtype=torch_dtype,
        device_map=None,  # disable automatic device mapping; on Ascend the model loads to CPU by default
        low_cpu_mem_usage=True,  # low-CPU-memory mode, avoids running out of host memory while loading
    )

    if args.quant_desc_path:
        from mindiesd import quantize
        if rank == 0:
            print("Quantizing Transformer on NPU (单独量化核心组件)...")
        
        quantize(
            model=transformer,
            quant_des_path=args.quant_desc_path,
            use_nz=True,
        )
        
        torch.npu.empty_cache()  # clear the NPU memory cache
        gc.collect()  # trigger Python garbage collection

    pipeline = QwenImagePipeline.from_pretrained(
        args.model_path,
        transformer = transformer,
        torch_dtype=torch_dtype,
        device_map=None,  # disable automatic device mapping (recommended on Ascend)
        low_cpu_mem_usage=True,  # low-CPU-memory mode
    )

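    # VAE optimizations (avoid running out of device memory)
    pipeline.vae.use_slicing = True
    pipeline.vae.use_tiling = True

    # Move the model to the target device
    pipeline.to(device)
    pipeline.set_progress_bar_config(disable=None)  # show the progress bar

    # Cache configuration (if enabled)
    if COND_CACHE or UNCOND_CACHE:
        # Conservative cache settings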
        cache_config = CacheConfig(
            method="dit_block_cache",
            blocks_count=60,
            steps_count=args.num_inference_steps,
            step_start=10,
            step_interval=3,
            step_end=35,
            block_start=10,
            block_end=50
        )
        pipeline.transformer.cache_cond = CacheAgent(cache_config) if COND_CACHE else None
        pipeline.transformer.cache_uncond = CacheAgent(cache_config) if UNCOND_CACHE else None
        if rank == 0:
            print("启用缓存配置")

    if args.cfg_size > 1 or args.ulysses_size > 1 or args.ring_size > 1 or args.tp_size > 1:
        assert args.cfg_size * args.ulysses_size * args.ring_size * args.tp_size == world_size, f"The number of cfg_size, ulysses_size and ring_size should be equal to the world size."
        sp_degree = args.ulysses_size * args.ring_size
        parallel_config = ParallelConfig(
            sp_degree=sp_degree,
            ulysses_degree=args.ulysses_size,
            ring_degree=args.ring_size,
            tp_degree=args.tp_size,
            use_cfg_parallel=(args.cfg_size==2),
            world_size=world_size,
        )
        init_parallel_env(parallel_config)
        if args.ulysses_size > 1:
            # 3. Parallelize the Transformer
            parallelize_qwen_image_transformer(pipeline)

    # 5. Start the Flask service (only rank 0 runs the web server)
    if rank == 0:
        logging.info("Starting Flask Online API Server on 10.119.1.59:8080...")
        flask_thread = threading.Thread(target=run_flask_app, daemon=True)
        flask_thread.start()

    # Make sure all distributed ranks are ready
    torch.distributed.barrier()
    
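    # 6. Resident inference polling loop (all NPUs proceed in lockstep)
    while True:
        # Broadcast container
        sync_data = [None]

        # Only rank 0 reads requests from the queue
        if rank == 0:
            try:
                # Block until an API request arrives
                req = request_queue.get()
                sync_data = [req]
            except Exception as e:
                pass

        # Broadcast over HCCL to every participating card so that all 8 cards start computing with identical inputs
        dist.broadcast_object_list(sync_data, src=0)

        # Extract the parameters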
        request_params = sync_data[0]
        if not request_params:
            continue
            
        lang = args.prompt_lang
        raw_prompt = request_params["prompt"]
        full_prompt = raw_prompt + " " + positive_magic.get(lang, "")
        negative_prompt = request_params["negative_prompt"]
        width = request_params["width"]
        height = request_params["height"]

        # Clean up before inference
        torch.npu.empty_cache()
        gc.collect()

        if rank == 0:
            logging.info(f"Start Parallel Inference... Size: {width}x{height}, Prompt: {raw_prompt}")
            start_time = time.time()
            
        try:
            # Run the joint 8-card inference in the parallel environment
            image = pipeline(
                prompt=full_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=args.num_inference_steps,
                true_cfg_scale=args.true_cfg_scale,
                generator=torch.Generator(device=device).manual_seed(args.seed),
            ).images[0]

            torch.npu.synchronize()
            
            # After a successful run the Diffusers pipeline output is already gathered on every card, so rank 0 alone handles the response
            if rank == 0:
                end_time = time.time()
                logging.info(f"Inference Completed. Time: {end_time - start_time:.2f} seconds")
                
                # Convert the image directly to an in-memory binary stream
                img_byte_arr = io.BytesIO()
                image.save(img_byte_arr, format='PNG')
                img_byte_arr.seek(0)
                
                # Put it on the queue; Flask picks it up and returns it to the user
                response_queue.put(img_byte_arr)
                
        except Exception as e:
            if rank == 0:
                logging.error(f"Inference error: {str(e)}")
                response_queue.put(e)

        finally:
            # Force-release everything held by the current request
            if 'image' in locals():
                del image
            del request_params, full_prompt, negative_prompt
            torch.npu.empty_cache() 
            gc.collect()

if __name__ == "__main__":
    main()
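
A client for the /generate endpoint defined above might look like the following. This is a minimal sketch using only the standard library, not code from the original repository: `normalize_size`, `build_request`, and `generate` are hypothetical helper names. The JSON fields, the 2512 upper bound, and the rounding down to a multiple of 16 mirror what the handler does, so the client can tell in advance which size the service will actually render.

```python
import json
import urllib.request

MAX_SIDE = 2512   # upper bound enforced by the /generate handler
MULTIPLE = 16     # the handler rounds sizes down to a multiple of 16

def normalize_size(value):
    """Mirror the server-side check and rounding for one image side."""
    if value > MAX_SIDE:
        raise ValueError(f"size must be <= {MAX_SIDE}")
    return (value // MULTIPLE) * MULTIPLE

def build_request(base_url, prompt, negative_prompt="", width=1024, height=1024):
    """Build a POST request carrying the JSON fields the handler reads."""
    payload = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": normalize_size(width),
        "height": normalize_size(height),
    }
    return urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def generate(base_url, prompt, out_path, **kwargs):
    """Send the request and save the PNG bytes returned by the service."""
    with urllib.request.urlopen(build_request(base_url, prompt, **kwargs)) as resp:
        with open(out_path, "wb") as f:
            f.write(resp.read())
```

For example, `generate("http://10.119.1.59:8080", "a cat on a sofa", "out.png", width=1000, height=1000)` would request a 992x992 image, assuming the service is reachable at the address configured in run_flask_app(). The handler blocks until inference finishes, so a single call can take as long as one full generation.
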
3) Create the launch script

Run the run.sh script. Check that the weight path is correct before starting the service. (ASCEND_RT_VISIBLE_DEVICES lists the NPU card IDs to use; the configuration below uses 8 cards.)

#!/bin/bash
export model_path="/weight/Qwen-Image-2512"

# Memory management: cap the maximum split size to prevent unusable memory fragments
export PYTORCH_NPU_ALLOC_CONF=max_split_size_mb:512

export HCCL_EXEC_TIMEOUT=0  # 0 means wait indefinitely, without the 1800 s limit

export LCCL_DETERMINISTIC=true  
export HCCL_DETERMINISTIC=true 
export ATB_MATMUL_SHUFFLE_K_ENABLE=0 
export ATB_LLM_LCOC_ENABLE=true      
export CLOSE_MATMUL_K_SHIFT=true     

# Operator optimizations
export ROPE_FUSE=1
export ADALN_FUSE=1

# Algorithm optimizations
export COND_CACHE=1
export UNCOND_CACHE=1

# 8 cards: cfg=1, ulysses=8 (matches the arguments passed below)
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
torchrun --nproc_per_node=8 --master-port 29508 run_cfg_usp-7.py \
    --model_path ${model_path} \
    --prompt_lang "en" \
    --num_inference_steps 50 \
    --seed 42 \
    --ulysses_size 8 \
    --cfg_size 1
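
The parallel sizes passed on the command line have to multiply out to the number of processes launched by torchrun: the service script asserts that cfg_size * ulysses_size * ring_size * tp_size equals the world size whenever any of them is greater than 1, and derives the sequence-parallel degree as ulysses_size * ring_size. The sketch below restates that rule as a standalone check; `check_parallel_sizes` is a hypothetical helper, not a function from the repository.

```python
def check_parallel_sizes(world_size, cfg_size=1, ulysses_size=1, ring_size=1, tp_size=1):
    """Return the sequence-parallel degree, or raise if the sizes do not fill world_size."""
    product = cfg_size * ulysses_size * ring_size * tp_size
    if product != world_size:
        raise ValueError(
            f"cfg_size * ulysses_size * ring_size * tp_size = {product}, "
            f"but world_size = {world_size}"
        )
    # Same derivation as sp_degree in the service script
    return ulysses_size * ring_size

if __name__ == "__main__":
    # The launch command above: 8 processes, ulysses_size=8, cfg_size=1
    print(check_parallel_sizes(8, cfg_size=1, ulysses_size=8))
    # An alternative 8-card split: cfg_size=2, ulysses_size=4
    print(check_parallel_sizes(8, cfg_size=2, ulysses_size=4))
```

Both example configurations are valid for 8 cards; a combination such as cfg_size=2 with ulysses_size=8 would raise, since it would need 16 processes.
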
4) Start the service

Once the 8-card parallel inference service starts, output similar to the following is printed:

[root@ww-qwen-image-2512-a3-mindie-worker-0 Qwen-Image]# ./run_cfg_usp-7.sh
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING] 
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING] *****************************************
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2026-03-05 14:47:05,545] torch.distributed.run: [WARNING] *****************************************
从 /weight/Qwen-Image-2512 加载模型...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 30.76it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.93it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.06it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 27.54it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.72it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 28.25it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 27.62it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 26.13it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.63it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.37it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 20.98it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 21.57it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.28it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 21.38it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.47it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.74it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.31it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.88it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 20.90it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.40it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.88it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22.49it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.78it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.95it/s]
启用缓存配置
[2026-03-05 14:47:21,439] WARNING: Model parallel is not initialized, initializing...
[2026-03-05 14:47:21,444] INFO: Starting Flask Online API Server on 10.119.1.59:8080...
 * Serving Flask app 'run_cfg_usp-7'
 * Debug mode: off
[2026-03-05 14:47:21,448] INFO: WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.
 * Running on http://10.119.1.59:8080
[2026-03-05 14:47:21,448] INFO: Press CTRL+C to quit
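
In the log above, the Flask server starts listening only after the weights have been loaded on all cards, so a client can poll the port to find out when the service is ready. The following is a minimal sketch using only the Python standard library. The function name and the default timeout values are hypothetical, and an open port shows only that Flask is accepting connections, not that a generation request will succeed.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout_s: float = 600.0,
                  interval_s: float = 2.0) -> bool:
    """Return True once a TCP connection to host:port succeeds.

    Retries every interval_s seconds and returns False once timeout_s
    has elapsed without a successful connection.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            # create_connection raises OSError while nothing is listening yet
            with socket.create_connection((host, port), timeout=interval_s):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval_s)
```

With the address from the log, the call would be `wait_for_port("10.119.1.59", 8080)`; replace the IP and port with those of your own environment.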

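VI. Testing the online generation service

Install the tools used to test image generation: yum install curl jq

Use the curl command to test text-to-image generation. The following command sends a prompt and saves the generated image as output_image.png.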

   curl -X POST http://10.119.1.59:8080/generate \
    -H "Content-Type: application/json" \
    -d '{
           "prompt": "A futuristic city under a neon sky.",
           "width": 1024,
           "height": 1024
           }' \
   -o output_image.png
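
The same request can be sent from Python. The sketch below mirrors the curl command using only the standard library. It assumes that `/generate` returns the raw image bytes in the response body, as the `-o output_image.png` option above suggests, and that the image is a PNG. The helper names and the 600 s timeout are hypothetical.

```python
import json
import urllib.request

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # first 8 bytes of every PNG file


def build_generate_request(base_url: str, prompt: str,
                           width: int = 1024, height: int = 1024) -> urllib.request.Request:
    """Build the POST request with the same JSON fields as the curl example."""
    payload = {"prompt": prompt, "width": width, "height": height}
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def looks_like_png(data: bytes) -> bool:
    """Check the PNG signature (assumes the service answers with a PNG)."""
    return data.startswith(PNG_SIGNATURE)


def generate_image(base_url: str, prompt: str, out_path: str,
                   timeout_s: float = 600.0) -> str:
    """Send the prompt, verify the response body and write it to out_path."""
    request = build_generate_request(base_url, prompt)
    # urlopen raises urllib.error.HTTPError for non-2xx responses
    with urllib.request.urlopen(request, timeout=timeout_s) as response:
        body = response.read()
    if not looks_like_png(body):
        raise RuntimeError("response body is not a PNG image")
    with open(out_path, "wb") as handle:
        handle.write(body)
    return out_path
```

For example, `generate_image("http://10.119.1.59:8080", "A futuristic city under a neon sky.", "output_image.png")` sends the same prompt as the curl command. The generous timeout reflects that a 50-step generation can take a while; adjust it to your hardware.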
1) Example result