Ascend-SACT/sdxl

1. Background

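Diffusion models are a class of generative models that can produce a wide variety of high-resolution images. Diffusers, released by Hugging Face, is the go-to library for state-of-the-art pretrained diffusion models, used to generate images, audio, and even 3D structures of molecules. The library contains many diffusion-based models and provides training and inference implementations for a range of downstream tasks.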

  • Reference implementation:

    url=https://github.com/huggingface/diffusers
    commit_id=5956b68a6927126daffc2c5a6d1a9a189defe288

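MindSpeed MM is Ascend's multimodal large-model suite for large-scale distributed training, supporting both multimodal generation and multimodal understanding. This document is an end-to-end practical guide covering environment setup, data preparation, fine-tuning, and inference validation on Ascend, intended for developers and researchers who want to use Huawei hardware to accelerate SDXL training and image generation. With the MindSpeed suite, users can efficiently fine-tune and deploy diffusion models on Ascend NPUs.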

2. Environment Preparation

2.1. Hardware Environment

| Hardware | Configuration | Notes |
| --- | --- | --- |
| Machine model | A3 SuperPoD | |
| Test cluster | 2-card Pod | Single node |

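2.2. Software Environment

| Software | Version | Deployment |
| --- | --- | --- |
| Driver | AscendHDK 25.2.0 | Host |
| Firmware | AscendHDK 25.2.0 | Host |
| Python | 3.10.18 | Container |
| CANN | 8.2.RC1 | Container |
| Torch | 2.6.0 | Container |
| Torch_npu | 2.6.0 | Container |
| MindSpeed | 2.1.0_core_r0.8.0 | Container |
| MindSpeed-LLM | 2.1.0 | Container |
| Megatron-LM | core_r0.8.0 | Container |
| Docker image OS | Ubuntu 20.04.6 | |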
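A quick way to confirm that the container's Python packages match the table is to query their installed versions. The following is a minimal sketch using the standard-library `importlib.metadata`. It only covers torch and torch_npu, assumes they are installed under those distribution names, and treats local build tags (`+...`) and post-release suffixes as matching. `version_mismatches` is a hypothetical helper, not part of MindSpeed:

```python
from importlib import metadata

# Container-side package versions from the software table above.
# Assumption: the distributions are installed under the names "torch" and "torch_npu".
EXPECTED = {"torch": "2.6.0", "torch_npu": "2.6.0"}


def version_mismatches(expected=EXPECTED, get_version=metadata.version):
    """Return {package: (expected, found)} for every package that is missing or has another version."""
    mismatches = {}
    for package, want in expected.items():
        try:
            found = get_version(package)
        except metadata.PackageNotFoundError:
            found = None
        # Drop a local build tag such as "+cpu", then accept suffixes like "2.6.0.post1" by prefix match
        if found is None or not found.split("+")[0].startswith(want):
            mismatches[package] = (want, found)
    return mismatches


if __name__ == "__main__":
    print(version_mismatches() or "torch / torch_npu versions match the table")
```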

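2.3. Image Preparation

The SDXL image has been released. For usage, see the SDXL image section of the Ascend Training Image Build Guide. Image download link: Ascend-SACT/ascend_train_image

Start the container: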

```shell
# Create and start the container in the background; the $(for ...) loop adds a --device flag
# for every /dev/davinci0 ... /dev/davinci15 NPU device node that exists on the host
docker run -d --net=host --shm-size=128g --privileged --name=sdxl \
$(for i in {0..15}; do [ -e "/dev/davinci$i" ] && echo --device=/dev/davinci$i; done) \
--device=/dev/davinci_manager \
--device=/dev/hisi_hdc \
--device=/dev/devmm_svm \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/dcmi:/usr/local/dcmi \
-e LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:$LD_LIBRARY_PATH \
aimp-mindspeed-mm2.1.0-pytorch2.6.0-ascend_cann8.2.rc1-py310-sdxl-arm64:0906 \
/bin/bash -c "tail -f /dev/null"

# Start the container again if it has been stopped, then open a shell inside it
docker start sdxl
docker exec -it sdxl bash
```
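The `$(for i in {0..15} ...)` substitution in the `docker run` command passes through only the `/dev/davinci<i>` NPU device nodes that actually exist on the host. The same selection logic is sketched below in Python purely for illustration; `davinci_device_flags` is a hypothetical name, and Docker itself is still started with the shell command:

```python
from pathlib import Path


def davinci_device_flags(dev_dir="/dev", max_devices=16):
    """Return a docker --device flag for each davinci<i> node (i = 0..max_devices-1) that exists."""
    flags = []
    for i in range(max_devices):
        node = Path(dev_dir) / f"davinci{i}"
        # Mirrors the shell test [ -e "/dev/davinci$i" ]: skip indices with no device node
        if node.exists():
            flags.append(f"--device={node}")
    return flags
```

On a host exposing davinci0 and davinci1, `davinci_device_flags()` would return `["--device=/dev/davinci0", "--device=/dev/davinci1"]`; the management nodes (`davinci_manager`, `hisi_hdc`, `devmm_svm`) are passed separately, as in the command.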

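3. Preparing the Pre-training Dataset

Download and extract the [pokemon-blip-captions](https://gitee.com/link?target=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Flambdalabs%2Fpokemon-blip-captions%2Ftree%2Fmain) dataset yourself, then set the `dataset_name` parameter in the launch shell script below to the absolute path of the local dataset.

Set `dataset_name` in `train_sdxl_deepspeed_**16.sh` to the absolute path of pokemon-blip-captions:

```shell
vim sdxl/train_sdxl_deepspeed_**16.sh
```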

The pokemon-blip-captions dataset has the following layout:

```
pokemon-blip-captions
├── dataset_infos.json
├── README.MD
└── data
    └── train-001.parquet
```

Note: the training script for this dataset is provided only as a reference example.
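Before launching training, a quick layout check can catch a wrong `dataset_name`. This is a minimal sketch assuming the directory tree shown above; the helper `check_dataset_layout` is hypothetical and only inspects files, it does not open the parquet data:

```python
from pathlib import Path


def check_dataset_layout(root):
    """Return a list of problems in a local pokemon-blip-captions copy; an empty list means it looks usable."""
    root = Path(root)
    problems = []
    # The launch script expects an absolute path in dataset_name
    if not root.is_absolute():
        problems.append(f"dataset_name is not an absolute path: {root}")
    # dataset_infos.json sits at the top level in the tree above
    if not (root / "dataset_infos.json").is_file():
        problems.append("missing dataset_infos.json")
    # Training samples are stored as parquet shards under data/
    data_dir = root / "data"
    shards = sorted(data_dir.glob("*.parquet")) if data_dir.is_dir() else []
    if not shards:
        problems.append("no *.parquet files under data/")
    return problems
```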

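4. Configuring the SDXL Training Model

With network access, the pretrained models can be fetched using the Hugging Face repo IDs below. Without network access, download the sdxl-base model (`model_name`) and the sdxl-vae model (`vae_name`) from the Hugging Face website yourself.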

```shell
export model_name="stabilityai/stable-diffusion-xl-base-1.0" # pretrained model path
export vae_name="madebyollin/sdxl-vae-fp16-fix" # VAE model path
```

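After obtaining the models, set `model_name` in the shell launch script below to the absolute path of the local pretrained model, and `vae_name` to the absolute path of the local VAE model: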

```shell
scripts_path="./sdxl" # model root directory (model folder name)
model_name="stabilityai/stable-diffusion-xl-base-1.0" # pretrained model path
vae_name="madebyollin/sdxl-vae-fp16-fix" # VAE model path
dataset_name="laion_sx" # dataset path
batch_size=4
max_train_steps=2000
mixed_precision="bf16" # mixed precision
resolution=1024
config_file="${scripts_path}/pretrain_${mixed_precision}_accelerate_config.yaml"

# Among the accelerate launch *** \ arguments:
# adjust --dataloader_num_workers to your system configuration and dataset size
--dataloader_num_workers=8 \
```
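In the bash script, update the path of train_text_to_image_sdxl_pretrain.py in the accelerate launch command (by default it is under diffusers/sdxl/):

```shell
# No change is needed if the model root directory is sdxl
accelerate launch --config_file ${config_file} \
  ${scripts_path}/train_text_to_image_sdxl_pretrain.py \
```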

Update the `deepspeed_config_file` path in pretrain_fp16_accelerate_config.yaml:

```yaml
deepspeed_config_file: ./sdxl/deepspeed_fp16.json # DeepSpeed JSON file path
```
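If you prefer to script this edit, here is a minimal sketch that rewrites the `deepspeed_config_file` value in the Accelerate YAML text while keeping its indentation. It uses a regular expression rather than a YAML parser, so comments elsewhere in the file survive; `set_deepspeed_config_file` is a hypothetical helper:

```python
import re

# Matches the key at any indentation (in Accelerate configs it is usually nested under deepspeed_config:)
_KEY_RE = re.compile(r"^([ \t]*deepspeed_config_file:[ \t]*).*$", re.MULTILINE)


def set_deepspeed_config_file(yaml_text, json_path):
    """Return yaml_text with the deepspeed_config_file value replaced by json_path."""
    # A function replacement avoids backslash-escape surprises in json_path
    new_text, count = _KEY_RE.subn(lambda m: m.group(1) + json_path, yaml_text)
    if count == 0:
        raise KeyError("deepspeed_config_file not found in the YAML text")
    return new_text


# Usage (paths from this guide):
# from pathlib import Path
# cfg = Path("sdxl/pretrain_fp16_accelerate_config.yaml")
# cfg.write_text(set_deepspeed_config_file(cfg.read_text(), "./sdxl/deepspeed_fp16.json"))
```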

Modify examples/text_to_image/train_text_to_image_sdxl.py:

```shell
vim examples/text_to_image/train_text_to_image_sdxl.py
```

  1. At line 58, change the minimum version:

    # Change the minimum version from 0.31.0 to 0.30.0
    check_min_version("0.30.0")
  2. At line 59, add the following code:

    from patch_sdxl import TorchPatcher, config_gc
    TorchPatcher.apply_patch()
    config_gc()
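To apply both edits from a script instead of vim, here is a minimal sketch that lowers the `check_min_version(...)` argument and inserts the three `patch_sdxl` lines directly after it. It assumes the call sits on its own unindented line, as in the diffusers example scripts. It locates the line by content rather than by line number, so it still works if line numbers shift, and it is idempotent: a file that already imports `patch_sdxl` is returned unchanged. `patch_training_script` is a hypothetical helper:

```python
import re

# The three lines this guide asks you to add after check_min_version(...)
PATCH_LINES = [
    "from patch_sdxl import TorchPatcher, config_gc",
    "TorchPatcher.apply_patch()",
    "config_gc()",
]
_MIN_VERSION_RE = re.compile(r'^check_min_version\("[^"]*"\)$', re.MULTILINE)


def patch_training_script(source, min_version="0.30.0"):
    """Lower check_min_version(...) and insert the patch_sdxl lines right after it."""
    if PATCH_LINES[0] in source:
        return source  # already patched
    replacement = f'check_min_version("{min_version}")\n' + "\n".join(PATCH_LINES)
    new_source, count = _MIN_VERSION_RE.subn(lambda m: replacement, source, count=1)
    if count != 1:
        raise ValueError("check_min_version(...) line not found")
    return new_source


# Usage:
# from pathlib import Path
# script = Path("examples/text_to_image/train_text_to_image_sdxl.py")
# script.write_text(patch_training_script(script.read_text()))
```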

5. Fine-tuning

5.1 Launching Fine-tuning

[Optional] On Ubuntu, add `accelerator.print("")` to examples/text_to_image/train_text_to_image_lora_sdxl.py and examples/controlnet/train_controlnet_sdxl.py (see the reference).

Note: in train_text_to_image_lora_sdxl, add it near line 1235; in train_controlnet_sdxl, add it near line 1307.

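[Saving LoRA weights at checkpoints for inference]

To save the LoRA weights at each checkpointing step, you must add the following import near the top of the script (in the same place as the patch added for SDXL pre-training):

```python
from patch_sdxl import save_Lora_Weights
```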

Then, near line 1227 of train_text_to_image_lora_sdxl.py, add `save_Lora_Weights(unwrap_model, unet, text_encoder_one, text_encoder_two, args.train_text_encoder, save_path)` directly below `accelerator.save_state(save_path)`, as follows:

```python
accelerator.save_state(save_path)
save_Lora_Weights(unwrap_model, unet, text_encoder_one, text_encoder_two, args.train_text_encoder, save_path)
logger.info(f"Saved state to {save_path}")
```
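To pick up the most recent LoRA weights for inference, the following sketch scans the training output directory. It assumes the diffusers convention of saving each checkpoint to `<output_dir>/checkpoint-<global_step>`, and it assumes (this guide does not confirm it) that `save_Lora_Weights` writes `pytorch_lora_weights.safetensors`, the default file name used by diffusers `save_lora_weights`, into that directory. Steps are compared numerically, so `checkpoint-1000` wins over `checkpoint-500`. `latest_lora_checkpoint` is a hypothetical helper:

```python
import re
from pathlib import Path

# diffusers training scripts save checkpoints as <output_dir>/checkpoint-<global_step>
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_lora_checkpoint(output_dir, weight_name="pytorch_lora_weights.safetensors"):
    """Return the checkpoint-<step> directory with the highest step that contains weight_name, or None."""
    best_step, best_dir = -1, None
    for child in Path(output_dir).iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        # Assumption: save_Lora_Weights writes weight_name into each checkpoint directory
        if match and child.is_dir() and (child / weight_name).is_file():
            step = int(match.group(1))  # numeric comparison, not lexicographic
            if step > best_step:
                best_step, best_dir = step, child
    return best_dir
```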

[Running the fine-tuning scripts]

```shell
# Single node, 8 cards
# The images needed by finetune_sdxl_controlnet_deepspeed_fp16.sh can be downloaded with:
# wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
# wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
bash sdxl/finetune_sdxl_controlnet_deepspeed_fp16.sh      # 8-card DeepSpeed training, sdxl_controlnet, fp16
bash sdxl/finetune_sdxl_lora_deepspeed_fp16.sh            # 8-card DeepSpeed training, sdxl_lora, fp16
bash sdxl/finetune_sdxl_deepspeed_fp16.sh                 # 8-card DeepSpeed training, sdxl_finetune, fp16
```

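5.2 Checking the Fine-tuning Results

View the training log

You can then follow training progress through the log written to the shared directory. The log path is: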

src/train25.10/mm-sdxl/diffusers/logs/train_bf16_sdxl_lora_deepspeed.log

A log like the one in the figure indicates that training has started successfully.

[Figure: training log after a successful start]
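To follow the loss as a series of numbers, here is a small sketch that pulls loss values out of the log. It assumes the log captures the tqdm progress bar of the diffusers training scripts, which display values such as `step_loss=0.123`. Because tqdm may put many updates on one physical line (separated by carriage returns), every match on a line is collected. Adjust the pattern if your log looks different; `read_step_losses` is a hypothetical helper:

```python
import re

# Assumption: progress-bar updates in the log contain "step_loss=<number>"
STEP_LOSS_RE = re.compile(r"step_loss=([-+]?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)")


def read_step_losses(lines):
    """Extract every step_loss value (as float) from an iterable of log lines."""
    losses = []
    for line in lines:
        for match in STEP_LOSS_RE.finditer(line):
            losses.append(float(match.group(1)))
    return losses


# Usage:
# with open("logs/train_bf16_sdxl_lora_deepspeed.log", encoding="utf-8", errors="replace") as f:
#     losses = read_step_losses(f)
```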

When the run finishes, the following files, including the output logs and the weights, are generated under /src/train25.10/outputs/sdxl/pokemon-blip-captions:

```
pokemon-blip-captions
├── train_bf16_sdxl_lora_deepspeed.log
├── pytorch_lora_weights.safetensors
├── StableDiffusionXLLoRADeepspeed_bs5_8p_acc.log
└── (nested TensorBoard log subdirectories)
    ├── events.out.tfevents.1756192224.d7fe100c467a.40152.0
    ├── events.out.tfevents.1756192224.d7fe100c467a.40152.1
    └── hparams.yml
```
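To sanity-check `pytorch_lora_weights.safetensors` without loading PyTorch, you can read its header directly. The safetensors format starts with an 8-byte little-endian unsigned integer giving the header length, followed by a JSON header that maps each tensor name to its `dtype`, `shape`, and `data_offsets` (plus an optional `__metadata__` entry). The following is a stdlib-only sketch; the function names are illustrative:

```python
import json
import struct


def read_safetensors_header(path):
    """Return {tensor_name: {"dtype", "shape", "data_offsets"}} from a .safetensors file."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64 little-endian header size
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header


def summarize_lora_weights(path):
    """Return (number of tensors, total parameter count) from the header alone."""
    header = read_safetensors_header(path)
    num_params = 0
    for info in header.values():
        count = 1
        for dim in info["shape"]:
            count *= dim
        num_params += count
    return len(header), num_params
```

An empty header or a parameter count far from what you expect for your LoRA rank suggests the weights were not saved correctly.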

6. Inference Validation

6.1 Running the Text-to-Image Inference Script

Run the inference script:

```shell
cd /src/train-25.10/mm-sdxl/diffusers/

# Load the CANN toolkit environment variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh

python sdxl/sdxl_text2img_lora_infer_pokemon.py
```

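Edit the prompts at line 25 of the script to generate different images. Each key is a prompt, and its value is the corresponding negative prompt: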

```python
prompts = dict()
prompts["masterpiece, best quality, Cute dragon creature, pokemon style, night, moonlight, dim lighting"] = "deformed, disfigured, underexposed, overexposed, rugged, (low quality), (normal quality),"
prompts["masterpiece, best quality, Pikachu walking in beijing city, pokemon style, night, moonlight, dim lighting"] = "deformed, disfigured, underexposed, overexposed, (low quality), (normal quality),"
prompts["masterpiece, best quality, red panda , pokemon style, evening light, sunset, rim lighting"] = "deformed, disfigured, underexposed, overexposed, (low quality), (normal quality),"
prompts["masterpiece, best quality, Photo of (Lion:1.2) on a couch, flower in vase, dof, film grain, crystal clear, pokemon style, dark studio"] = "deformed, disfigured, underexposed, overexposed, "
prompts["masterpiece, best quality, siberian cat pokemon on river, pokemon style, evening light, sunset, rim lighting, depth of field"] = "deformed, disfigured, underexposed, overexposed, "
prompts["masterpiece, best quality, pig, Exquisite City, (sky:1.3), (Miniature tree:1.3), Miniature object, many flowers, glowing mushrooms, (creek:1.3), lots of fruits, cute colorful animal protagonist, Firefly, meteor, Colorful cloud, pokemon style, Complicated background, rainbow,"] = "Void background,black background,deformed, disfigured, underexposed, overexposed, "
prompts["masterpiece, best quality, (pokemon), a cute pikachu, girl with glasses, (masterpiece, top quality, best quality, official art, beautiful and aesthetic:1.2),"] = "(low quality), (normal quality), (monochrome), lowres, extra fingers, fewer fingers, (watermark), "
prompts["masterpiece, best quality, sugimori ken $style$, (pokemon $creature$), pokemon electric type, grey and yellow skin, mechanical arms, cyberpunk city background, night, neon light"] = "(worst quality, low quality:1.4), watermark, signature, deformed, disfigured, underexposed, overexposed, "
```

The console prints a log like this:

[Figure: console log of the inference run]

6.2 Text-to-Image Inference Results

When the run finishes, the following images are generated under /src/train25.10/outputs/sdxl/pokemon-blip-captions/sdxl_lora_NPU:

```
sdxl_lora_NPU
├──  red panda , p-23.png
├──  Photo of (Lio-8.png
├──  Photo of (Lio-42.png
├──  siberian cat -42.png
├──  (pokemon), a -23.png
├──  Pikachu walki-8.png
├──  red panda , p-8.png
├──  Photo of (Lio-1334.png
├──  sugimori ken -42.png
├──  red panda , p-1334.png
├──  (pokemon), a -1334.png
├──  sugimori ken -8.png
├──  sugimori ken -23.png
├──  red panda , p-42.png
├──  Photo of (Lio-23.png
├──  siberian cat -23.png
├──  (pokemon), a -42.png
├──  Cute dragon c-23.png
├──  pig, Exquisit-8.png
├──  pig, Exquisit-23.png
├──  (pokemon), a -8.png
├──  sugimori ken -1334.png
├──  Cute dragon c-8.png
├──  Cute dragon c-1334.png
├──  Pikachu walki-23.png
├──  Pikachu walki-42.png
├──  Pikachu walki-1334.png
├──  siberian cat -8.png
├──  Cute dragon c-42.png
├──  siberian cat -1334.png
├──  pig, Exquisit-42.png
└──  pig, Exquisit-1334.png
```
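The names in this listing look like 14 characters of each prompt, taken right after the shared `masterpiece, best quality,` prefix (so they start with a space), followed by `-<seed>.png`, with seeds 8, 23, 42, and 1334. That gives 8 prompts × 4 seeds = 32 files, matching the listing. This naming rule is inferred from the listing, not read from `sdxl_text2img_lora_infer_pokemon.py`. Under that assumption, the following sketch checks whether a run produced every expected image (helper names are hypothetical):

```python
import os

PREFIX = "masterpiece, best quality,"  # shared by every prompt in the script
SEEDS = (8, 23, 42, 1334)  # read off the file names above, not from the script


def image_filename(prompt, seed, width=14):
    """Rebuild a name such as ' red panda , p-23.png' (assumed scheme: 14 chars after PREFIX + '-<seed>.png')."""
    start = len(PREFIX) if prompt.startswith(PREFIX) else 0
    return f"{prompt[start:start + width]}-{seed}.png"


def expected_images(prompts, seeds=SEEDS):
    """All file names one run should produce: every prompt combined with every seed."""
    return sorted(image_filename(p, s) for p in prompts for s in seeds)


def missing_images(output_dir, prompts, seeds=SEEDS):
    """Expected file names that are not present in output_dir."""
    present = set(os.listdir(output_dir))
    return [name for name in expected_images(prompts, seeds) if name not in present]
```

For example, `missing_images(out_dir, list(prompts))` with the `prompts` dict from the inference script should return an empty list when all 32 images are present.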

You can view the generated images directly:

[Figure: generated sample images]