Ascend-SACT/3DDiffusionPolicy
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

3D Diffusion Policy - Ascend NPU 适配

一、模型介绍

3D Diffusion Policy (DP3) 是一种基于3D点云的机器人视觉模仿学习方法,将点云表征与扩散策略相结合。

项目信息
论文arXiv:2403.03954, RSS 2024
项目地址https://github.com/YanjieZe/3D-Diffusion-Policy
用途机器人视觉模仿学习
输入点云 (B, N, 3) + 机器人状态 (B, D)
输出机器人动作序列 (B, horizon, action_dim)

核心组件:

  • DP3Encoder: PointNet 点云编码器
  • ConditionalUnet1D: 扩散模型 backbone
  • DDPMScheduler: 去噪调度器

二、环境要求

2.1 硬件环境

型号说明
Ascend 910 (A3)已验证
Ascend 910B兼容

2.2 软件环境

软件名版本说明
CANN8.5.x昇腾软件栈
Python3.12容器提供
PyTorch2.7.1+cpu基础框架
torch-npu2.7.1.post2NPU 支持

2.3 依赖版本

依赖版本说明
diffusers0.11.1需配合 huggingface_hub 0.14.1
hydra-core1.3.2Python 3.12 兼容
zarr3.1.6Python 3.12 兼容
einops0.4.1—

三、快速开始

3.1 创建容器

docker run -d \
  --name dp3_npu \
  --privileged \
  --network host \
  -v /data:/data \
  swr.cn-north-4.myhuaweicloud.com/ascend-sact/ascend-a3-ubuntu:v3.3 \
  /bin/bash -c "while true; do sleep 3600; done"

3.2 进入容器并配置环境

# 进入容器
docker exec -it dp3_npu bash

# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/common:$LD_LIBRARY_PATH
export ASCEND_DEVICE_ID=0

3.3 安装依赖

# 安装 PyTorch + torch-npu
pip3 install torch==2.7.1+cpu --index-url https://download.pytorch.org/whl/cpu
pip3 install torch-npu==2.7.1.post2

# 安装项目依赖
pip3 install huggingface_hub==0.14.1
pip3 install diffusers==0.11.1
pip3 install hydra-core==1.3.2
pip3 install zarr==3.1.6
pip3 install einops==0.4.1
pip3 install tqdm wandb

3.4 克隆项目

cd /data
git clone https://github.com/YanjieZe/3D-Diffusion-Policy.git
cd 3D-Diffusion-Policy

3.5 准备数据

方式一:使用官方数据

pip3 install gdown
gdown "https://drive.google.com/uc?id=1G5MP6Nzykku9sDDdzy7tlRqMBnKb253O" -O /tmp/dp3_data.zip
unzip /tmp/dp3_data.zip -d data/

方式二:生成合成数据(测试用)

python3 scripts/generate_synthetic_data.py

3.6 运行训练

cd 3D-Diffusion-Policy

# 训练
python3 scripts/train_dp3_npu.py --mode train

# 或训练+推理测试
python3 scripts/train_dp3_npu.py --mode both

3.7 运行推理

# 加载训练好的模型进行推理
python3 scripts/train_dp3_npu.py --mode inference

预期输出:

Input point_cloud shape: torch.Size([1, 2, 512, 3])
Input agent_pos shape: torch.Size([1, 2, 24])
Output action shape: torch.Size([1, 16, 26])
Inference test PASSED

四、适配说明

4.1 关键问题与解决方案

问题 1: horizon 维度不匹配 ⚠️

现象: horizon=10 时训练报错

原因: ConditionalUnet1D 使用 Conv1d stride=2 下采样,非 2 的幂次导致 skip connection 维度不匹配

解决: 使用 horizon=16 或 32 或 64

# ✅ 正确
config.horizon = 16

# ❌ 错误
config.horizon = 10

问题 2: Python 3.12 兼容性

现象: hydra-core / zarr 安装失败

解决: 升级到兼容版本

pip3 install hydra-core==1.3.2  # 原始 1.2.0
pip3 install zarr==3.1.6        # 原始 2.12.0

问题 3: diffusers 导入错误

现象: HfFolder 相关错误

解决: 安装兼容版本

pip3 install huggingface_hub==0.14.1
pip3 install diffusers==0.11.1

4.2 Checkpoint 保存

保存时需排除 normalizer 参数:

model_state = {k: v for k, v in model.state_dict().items()
               if not k.startswith('normalizer')}
checkpoint = {
    'model_state_dict': model_state,
    'normalizer_state': normalizer.state_dict(),
    ...
}

五、性能指标

5.1 NPU vs CPU 性能对比

指标NPUCPU加速比
训练吞吐量4.81 iter/s0.17 iter/s28.87x
10步推理93.66ms4524ms48.30x

5.2 推理性能详情

去噪步数NPU 时间说明
5步47ms快速推理
10步94ms默认推荐
20步187ms高质量
50步467ms最高质量

单位步数耗时: ~9.3ms/步

5.3 精度结果

测试项结果
训练 Loss (10 epochs)0.1004
推理稳定性 (方差)0.415
预测 MSE0.5685

六、推荐配置

参数推荐值原因
horizon16 或 32U-Net 维度匹配
n_obs_steps2默认值
batch_size32NPU 内存允许
推理步数10-20平衡速度与质量

七、常见问题

Q1: horizon=10 报错怎么办?

A: 必须使用 2 的幂次 (16, 32, 64),这是 U-Net 架构的限制。

Q2: NPU 内存不足怎么办?

A: 减小 batch_size 或 horizon 参数。

Q3: 如何验证 NPU 可用?

import torch
import torch_npu
print(f"NPU available: {torch.npu.is_available()}")

Q4: 推理结果不稳定?

A: 这是 DDPM 的固有随机性,可通过设置随机种子或使用 DDIM scheduler 改善。

八、文件结构

scripts/
├── train_dp3_npu.py       # NPU 训练脚本
├── test_performance.py    # 性能测试脚本
├── generate_synthetic_data.py  # 合成数据生成
└── setup_env.sh           # 环境配置脚本

data/
├── outputs/
│   └── dp3_npu_full.pt    # 训练模型
└── synthetic_test.zarr/   # 合成数据

九、参考资料

  • 项目地址: https://github.com/YanjieZe/3D-Diffusion-Policy
  • 论文: arXiv:2403.03954, RSS 2024
  • 项目主页: https://3d-diffusion-policy.github.io/
  • CANN 文档: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/
  • torch-npu: https://gitee.com/ascend/pytorch

十、License

本项目适配代码采用 Apache-2.0 许可证,原模型遵循其自身许可证。