HuggingFace镜像/DiT
模型介绍文件和版本分析
下载使用量0

模型推理指导

一、模型简介

DiT是一种基于Transformer的扩散模型,全称为Diffusion Transformer,DiT遵循ViT的技术方法。有关DiT模型的更多信息,请参考DiT github。

二、环境准备

表 1 版本配套表

配套版本环境准备指导
Python3.10/3.11-
PyTorch2.9.0-

注意:

  • 该模型也支持torch 2.1.0等版本

2.1 CANN开发套件包+kernel包+MindIE包下载

  1. 环境准备指导
  • CANN开发套件包+kernel包安装
# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。
chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run
chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run
# 校验软件包安装文件的一致性和完整性
./Ascend-cann-toolkit_{version}_linux-{arch}.run --check
./Ascend-cann-kernels-{soc}_{version}_linux.run --check
# 安装
./Ascend-cann-toolkit_{version}_linux-{arch}.run --install
./Ascend-cann-kernels-{soc}_{version}_linux.run --install

# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh
  • MindIE包安装
# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。
chmod +x ./Ascend-mindie_${version}_linux-${arch}.run
./Ascend-mindie_${version}_linux-${arch}.run --check

# 方式一:默认路径安装
./Ascend-mindie_${version}_linux-${arch}.run --install
# 设置环境变量
cd /usr/local/Ascend/mindie && source set_env.sh

# 方式二:指定路径安装
./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath}
# 设置环境变量
cd ${AieInstallPath}/mindie && source set_env.sh
  • MindIE SD不需要单独安装,安装MindIE时将会自动安装
  • torch_npu 安装: 下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz
tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz
# 解压后,会有whl包
pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl

2.2 安装所需依赖

2.2.1 下载源码

git clone https://modelers.cn/MindIE/DiT.git

2.2.2 安装环境依赖

cd DiT
pip install -r requirements.txt --no-deps

2.2.3 编译fatik算子plugin

注:只有300I Duo设备下需要执行此步骤

cd pta_plugin
bash build.sh
cd ..

三、模型权重

DiT权重文件下载链接如下,按需下载:

DiT-XL-2-256x256下载链接

DiT-XL-2-512x512下载链接

vae权重文件下载链接如下,按需下载:

# ema
git clone https://huggingface.co/stabilityai/sd-vae-ft-ema
# mse
git clone https://huggingface.co/stabilityai/sd-vae-ft-mse

四、模型推理

4.1 性能测试

  1. 开启cpu高性能模式
echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
sysctl -w vm.swappiness=0
sysctl -w kernel.numa_balancing=0
  1. 安装绑核工具
apt-get update
apt-get install numactl

查询卡的NUMA node

lspci -vs bus-id

bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字

可通过lscpu获得NUMA node对应的CPU核数

NUMA node0: 0-23
NUMA node1: 24-47
NUMA node2: 48-71
NUMA node3: 72-95

当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 2. 执行推理命令:

# Atlas 800I A2
numactl -C 0-23 python3 sample.py \
   --vae mse \
   --image_size 512 \
   --ckpt ./DiT-XL-2-512x512.pt \
   --vae_model ./sd-vae-ft-mse \
   --device_id 0 \
   --class_label 0
  
# Atlas 300I Duo
export CPU_AFFINITY_CONF=1
export TASK_QUEUE_ENABLE=1
numactl -C 0-23 torchrun --nproc_per_node=2 sample.py \
   --vae mse \
   --image_size 512 \
   --ckpt ./DiT-XL-2-512x512.pt \
   --vae_model ./sd-vae-ft-mse \
   --class_label 0 \
   --parallel

参数说明:

  • --vae:使用哪种vae模型,支持mse和ema
  • --image_size:分辨率,支持256和512。默认为512
  • --ckpt:DiT-XL-2的权重路径
  • --vae_model: vae模型的权重路径,注意和vae参数要匹配
  • --device_id:使用哪张卡
  • --class_label:可在0~999中任意指定一个整数,代表image_net的种类
  • --parallel:【可选】模型使用并行进行推理

执行完成后会在当前路径生成sample.png

4.2 精度测试

  1. 下载数据集 ImageNet512x512(VIRTUAL_imagenet512.npz)和ImageNet256x256(VIRTUAL_imagenet256_labeled.npz)

  2. 使用脚本读取数据集,生成图片

# Atlas 300I Duo
torchrun --nproc_per_node=2 fid_test.py \
   --vae mse \
   --image_size 512 \
   --ckpt ./DiT-XL-2-512x512.pt \
   --vae_model ./sd-vae-ft-mse \
   --parallel \
   --results results

# Atlas 800I A2
   python3 fid_test.py \
   --vae mse \
   --image_size 512 \
   --ckpt ./DiT-XL-2-512x512.pt \
   --vae_model ./sd-vae-ft-mse \
   --device_id 0 \
   --results results

参数说明:

  • --results:生成的1000张图片存放路径
  • image_size:分辨率,支持256和512。默认为512
  1. 计算FID:
# 512分辨率使用VIRTUAL_imagenet512.npz数据集
python3 -m pytorch_fid ./VIRTUAL_imagenet512.npz ./results
# 256分辨率使用VIRTUAL_imagenet256_labeled.npz数据集
python3 -m pytorch_fid ./VIRTUAL_imagenet256_labeled.npz ./results 

五、推理结果参考

模型推理性能

待测试

声明

  • 本代码仓提到的数据集和模型仅作为示例,这些数据集和模型仅供您用于非商业目的,如您使用这些数据集和模型来完成示例,请您特别注意应遵守对应数据集和模型的License,如您因使用数据集或模型而产生侵权纠纷,华为不承担任何责任。
  • 如您在使用本代码仓的过程中,发现任何问题(包括但不限于功能问题、合规问题),请在本代码仓提交issue,我们将及时审视并解答。