Ascend-SACT/DPOT
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

DPOT模型昇腾NPU部署指导

一、模型介绍

1.1 模型概述

DPOT(Auto-Regressive Denoising Operator Transformer,自回归去噪算子 Transformer)是面向大规模偏微分方程(PDEs)预训练与算子学习设计的神经算子模型,核心落地场景为地震波全波形反演(FWI)——作为油气资源勘探、地下结构成像的核心技术方向,DPOT旨在替代传统FWI流程中的CBS求解器,解决其迭代收敛慢、计算成本高、泛化性有限的行业痛点。

1.2 核心模块说明

基于提供的源码,DPOT模型核心模块包括:

模块名功能说明
AFNO2D傅里叶域注意力模块,在频域实现高效的空间特征混合,支持分块并行计算
PrjConvTP分布式3D卷积模块,支持昇腾多卡环境下按D维度切分输入,降低单卡显存占用
Block模型基础块,整合AFNO/SlimAFNO特征混合与MLPConv(3D卷积MLP),支持残差连接
PatchEmbed3D/2D补丁嵌入模块,实现输入数据的维度映射与分布式切分
DPOTNet模型主类,整合输入平滑、补丁嵌入、位置编码、特征提取、输出解码全流程

1.3 关键特性

  • 支持3D空间数据处理,适配昇腾多卡分布式训练/推理;
  • 基于MindSpore框架开发,迁移适配昇腾NPU算子;
  • 采用张量并行(TP)技术切分3D卷积维度,适配昇腾910B多卡算力;
  • 融合傅里叶域注意力与卷积网络,兼顾全局特征与局部特征提取。

二、部署环境

2.1 关键组件信息

项目规格参数
计算卡Ascend 910B3
驱动固件25.2.0
CANN 版本8.5.0
操作系统openEuler 22.03
架构aarch64
MindSpore2.8.0
MindSpeedMindSpeed-LLM master
Megatroncore_v0.12.1
Python3.10.19
torch2.7.1
torch_npu2.7.1.post2

2.2 安装驱动固件

请根据系统和硬件产品型号选择对应版本的社区版本或商用版本的驱动与固件。参考如下命令安装:

chmod +x Ascend-hdk-<chip_type>-npu-driver_<version>_linux-<arch>.run
chmod +x Ascend-hdk-<chip_type>-npu-firmware_<version>.run
./Ascend-hdk-<chip_type>-npu-driver_<version>_linux-<arch>.run --full --force
./Ascend-hdk-<chip_type>-npu-firmware_<version>.run --full

2.3 安装CANN

获取CANN,安装配套版本的Toolkit、ops和NNAL并配置CANN环境变量,具体请参考CANN软件安装指南

#基于PyTorch框架设置环境变量
source /usr/local/Ascend/cann/set_env.sh # 修改为实际安装的Toolkit包路径
source /usr/local/Ascend/nnal/atb/set_env.sh # 修改为实际安装的nnal包路径

2.4 安装Pytorch及torch_npu

# 安装torch和torch_npu构建参考 https://gitcode.com/ascend/pytorch/releases
pip3 install torch-2.7.1-cp310-cp310-manylinux_2_28_aarch64.whl 
pip3 install torch_npu-2.7.1rc1-cp310-cp310-manylinux_2_28_aarch64.whl

2.5 安装MindSpore

pip install mindspore==2.8.0 -i https://repo.mindspore.cn/pypi/simple --trusted-host repo.mindspore.cn --extra-index-url https://repo.huaweicloud.com/repository/pypi/simple

2.6 安装MindSpeed LLM

安装MindSpeed加速库

git clone https://gitcode.com/ascend/MindSpeed.git
cd MindSpeed
git checkout master  # checkout commit from MindSpeed master
pip3 install -r requirements.txt 
pip3 install -e .
cd ..

准备MindSpeed LLM及Megatron-LM源码

git clone https://gitcode.com/ascend/MindSpeed-LLM.git 
git clone https://github.com/NVIDIA/Megatron-LM.git  # 从github下载Megatron-LM,请确保网络能访问
cd Megatron-LM
git checkout core_v0.12.1
cp -r megatron ../MindSpeed-LLM/
cd ../MindSpeed-LLM
git checkout master

pip3 install -r requirements.txt  # 安装其余依赖库

三、核心适配点

3.1 框架基础替换

所有网络类继承nn.Cell(MindSpore 核心基类),替代 PyTorch 的nn.Module

# L43:AFNO2D继承nn.Cell(替代nn.Module)
class AFNO2D(nn.Cell):
    def __init__(...):
        super().__init__()
    # L79:construct替代forward
    def construct(self, x, spatial_size = None):

# L627:DPOTNet主网络继承nn.Cell
class DPOTNet(nn.Cell):
    # L691:construct作为前向入口
    def construct(self, x):

3.2 分布式 3D 卷积适配(昇腾多卡新增逻辑)

# L247-L330:PrjConvTP类(昇腾多卡3D卷积切分)
class PrjConvTP(nn.Cell):
    def __init__(self, in_chans, embed_dim, out_dim, k_size, stride, padding, act):
        super().__init__()
        # 卷积padding设为0,手动pad(适配昇腾NPU)
        self.c1 = mint.nn.Conv3d(in_chans, embed_dim, kernel_size=k_size, stride=stride, padding=0, bias=False)
        self.rank = get_rank()  # MindSpore分布式接口
        self.world_size = get_world_size()

    def construct(self, x):
        # 手动replicate padding(替代PyTorch卷积内置padding)
        x_pad = mint.nn.functional.pad(x, (self.pad_w, self.pad_w, self.pad_h, self.pad_h, self.pad_d, self.pad_d), mode="replicate")
        # 按rank切分D维输入
        local_D_out = D_out // self.world_size
        d_out_start = self.rank * local_D_out
        d_out_end = (self.rank + 1) * local_D_out
        # 分布式通信拼接输出
        mint.distributed.all_gather(gather_list, y2_local)
        y_full = mint.cat(gather_list, dim=2)
        return y_full

3.3 重计算机制(昇腾显存优化)

# L511-L515:OutConv中启用重计算(PyTorch无原生支持)
self.conv1.recompute()
self.act.recompute()
self.r_layer.recompute()
self.conv2.recompute()
self.conv3.recompute()

3.4 频域算子替换

# L89-L91:替换torch.fft.rfft2为RDFTn
xr, xi = RDFTn((x.shape[1],x.shape[2]), dim=(1,2), norm="ortho")(x)
# x = torch.fft.rfft2(x, dim=(1, 2))  # 原PyTorch代码

# L131-L133:替换torch.fft.irfft2为IRDFTn
#x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho") # 原PyTorch代码
x = IRDFTn(shape=(H, W), dim=(1,2), norm="ortho")(o2_real, o2_imag)

四、启动训练/推理

4.1 训练主脚本

import os
import argparse
import mindspore as ms
import mindspore.nn as nn
from mindspore.communication import init, get_rank, get_world_size
from mindspore import context, Tensor, Parameter
from mindspore.train import Model, Callback, LossMonitor, TimeMonitor
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.common import initializer

# 导入自定义模型
from dpot_ms_3d import DPOTNet

# 分布式训练配置
def parse_args():
    parser = argparse.ArgumentParser(description="DPOTNet 4-card distributed training")
    parser.add_argument('--device_target', type=str, default='GPU', choices=['GPU', 'Ascend'])
    parser.add_argument('--distribute', action='store_true', default=True, help='Use distributed training')
    parser.add_argument('--epoch_size', type=int, default=100, help='Total training epochs')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size per card')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--img_size', type=tuple, default=(124, 124, 124), help='3D input size')
    parser.add_argument('--patch_size', type=int, default=4, help='Patch size for embedding')
    parser.add_argument('--save_ckpt_dir', type=str, default='./ckpt_dist', help='Checkpoint save dir')
    return parser.parse_args()

# 自定义损失函数(根据任务适配,示例为MSE + 分类损失)
class DPOTLoss(nn.Cell):
    def __init__(self, out_channels, n_cls):
        super().__init__()
        self.reg_loss = nn.MSELoss()
        self.cls_loss = nn.CrossEntropyLoss()
        self.out_channels = out_channels
        self.n_cls = n_cls

    def construct(self, pred, label, cls_pred, cls_label):
        # 回归损失(3D输出)
        reg_loss = self.reg_loss(pred, label)
        # 分类损失
        cls_loss = self.cls_loss(cls_pred, cls_label)
        # 总损失
        total_loss = reg_loss + 0.1 * cls_loss
        return total_loss

# 构建数据加载器(示例,需替换为真实数据逻辑)
def create_dataset(batch_size, rank, world_size, img_size=(124,124,124)):
    """
    分布式数据加载:按rank切分数据,确保每张卡加载不同样本
    """
    # 模拟3D数据 [B, D, H, W, C]
    batch_data = Tensor(np.random.randn(batch_size, *img_size, 1), ms.float32)
    batch_label = Tensor(np.random.randn(batch_size, *img_size, 2), ms.float32)
    cls_label = Tensor(np.random.randint(0, 12, (batch_size,)), ms.int32)
    
    # 真实场景需替换为:
    # 1. 读取数据集文件列表,按rank%world_size切分
    # 2. 使用mindspore.dataset构建迭代器
    # 示例返回单批次数据(训练循环中需循环迭代)
    return batch_data, batch_label, cls_label

def main():
    args = parse_args()

    # 1. 初始化分布式环境
    if args.distribute:
        # 初始化GPU/Ascend分布式环境
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
        init(backend_name='nccl' if args.device_target == 'GPU' else 'hccl')
        rank_id = get_rank()
        world_size = get_world_size()
        print(f"Distributed init success: rank={rank_id}, world_size={world_size}")
    else:
        rank_id = 0
        world_size = 1

    # 2. 构建模型
    model = DPOTNet(
        img_size=args.img_size,
        patch_size=args.patch_size,
        mixing_type='slim_afno',  # 推荐用slim_afno更高效
        in_channels=1,
        out_channels=2,
        in_timesteps=1,
        out_timesteps=1,
        n_blocks=4,
        embed_dim=1024,
        out_layer_dim=128,
        depth=8,
        modes=32,
        mlp_ratio=2.0,
        n_cls=12,
        normalize=True,
        act='gelu',
        time_agg='exp_mlp',
        spatial_dims=3,
        slim_rank=8,
        slim_num_bands=2,
        slim_use_lora=False
    )

    # 3. 构建损失函数和优化器
    loss_fn = DPOTLoss(out_channels=2, n_cls=12)
    optimizer = nn.Adam(model.trainable_params(), learning_rate=args.lr)

    # 4. 构建训练模型
    train_model = Model(model, loss_fn=loss_fn, optimizer=optimizer, amp_level="O2")

    # 5. 回调函数配置(仅主卡保存ckpt)
    callbacks = [TimeMonitor(), LossMonitor()]
    if rank_id == 0:
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=100,
            keep_checkpoint_max=10,
            integrated_save=False
        )
        ckpt_cb = ModelCheckpoint(
            prefix="dpot_net",
            directory=os.path.join(args.save_ckpt_dir, f"rank_{rank_id}"),
            config=ckpt_config
        )
        callbacks.append(ckpt_cb)

    # 6. 训练循环
    for epoch in range(args.epoch_size):
        if rank_id == 0:
            print(f"===== Epoch {epoch+1}/{args.epoch_size} =====")
        
        # 加载批次数据(真实场景需替换为数据集迭代器)
        train_data, train_label, cls_label = create_dataset(
            batch_size=args.batch_size,
            rank=rank_id,
            world_size=world_size,
            img_size=args.img_size
        )

        # 单步训练
        train_model.train(
            epoch=1,
            train_dataset=[(train_data, train_label, cls_label)],
            callbacks=callbacks,
            dataset_sink_mode=False  # 若用真实Dataset,设为True提升性能
        )

    if rank_id == 0:
        print("Training finished!")

if __name__ == "__main__":
    main()

4.2 推理主脚本

import os
import argparse
import numpy as np
import mindspore as ms
import mindspore.numpy as mnp
from mindspore import context, Tensor, load_checkpoint, load_param_into_net
from mindspore import dtype as mstype
from mindspore.train import Model
from mindspore.nn import Softmax

# 导入自定义模型
from dpot_ms_3d import DPOTNet

# -------------------------- 配置参数 --------------------------
def parse_args():
    parser = argparse.ArgumentParser(description="DPOTNet 3D推理脚本")
    parser.add_argument('--device_target', type=str, default='NPU', choices=['GPU', 'NPU', 'CPU'])
    parser.add_argument('--device_id', type=int, default=0, help='使用的设备ID(单卡)')
    parser.add_argument('--ckpt_path', type=str, required=True, help='训练好的模型权重路径')
    parser.add_argument('--input_data_path', type=str, required=True, help='输入3D数据路径(npy文件)')
    parser.add_argument('--output_data_path', type=str, default='./infer_result.npy', help='推理结果保存路径')
    parser.add_argument('--img_size', type=int, nargs='+', default=[124,124,124], help='3D输入尺寸 (D,H,W)')
    parser.add_argument('--patch_size', type=int, default=4, help='Patch size for embedding')
    parser.add_argument('--in_channels', type=int, default=1, help='输入通道数')
    parser.add_argument('--out_channels', type=int, default=2, help='输出通道数')
    parser.add_argument('--normalize', action='store_true', default=True, help='是否使用归一化(与训练对齐)')
    return parser.parse_args()

# -------------------------- 数据预处理 --------------------------
def preprocess_data(input_path, img_size, normalize=True):
    """
    数据预处理(与训练时对齐)
    输入:npy文件,维度 [D, H, W, C] 或 [C, D, H, W]
    输出:适配模型输入的Tensor [1, D, H, W, C](batch_size=1)
    """
    # 读取3D数据(示例为npy格式,可替换为其他格式如nii、h5)
    data = np.load(input_path).astype(np.float32)
    print(f"原始输入数据维度: {data.shape}")

    # 维度调整:确保最终为 [D, H, W, C]
    if data.ndim == 3:  # [D,H,W] → 补通道维 [D,H,W,1]
        data = np.expand_dims(data, axis=-1)
    elif data.ndim == 4 and data.shape[0] == 1:  # [C,D,H,W] → 转 [D,H,W,C]
        data = np.transpose(data, (1,2,3,0))
    
    # 尺寸校验(与模型输入对齐)
    assert data.shape[:3] == tuple(img_size), \
        f"输入尺寸 {data.shape[:3]} 与模型要求 {img_size} 不匹配"
    assert data.shape[-1] == args.in_channels, \
        f"输入通道数 {data.shape[-1]} 与模型要求 {args.in_channels} 不匹配"

    # 归一化(与训练时的normalize逻辑对齐)
    if normalize:
        mu = np.mean(data, axis=(0,1,2), keepdims=True)
        sigma = np.std(data, axis=(0,1,2), keepdims=True) + 1e-6
        data = (data - mu) / sigma
        # 保存归一化参数,用于后处理反归一化
        np.save('./norm_params.npy', {'mu': mu, 'sigma': sigma})

    # 增加batch维度 → [1, D, H, W, C]
    data = np.expand_dims(data, axis=0)
    # 转为MindSpore Tensor
    return Tensor(data, dtype=mstype.float32)

# -------------------------- 结果后处理 --------------------------
def postprocess_data(pred, normalize=True):
    """
    结果后处理(反归一化、维度调整)
    """
    # Tensor转numpy
    pred_np = pred.asnumpy()  # [1, D, H, W, out_channels]
    pred_np = np.squeeze(pred_np, axis=0)  # 去掉batch维 → [D,H,W,out_channels]

    # 反归一化(如果训练时开启了normalize)
    if normalize:
        norm_params = np.load('./norm_params.npy', allow_pickle=True).item()
        pred_np = pred_np * norm_params['sigma'] + norm_params['mu']
    
    return pred_np

# -------------------------- 主推理流程 --------------------------
def main():
    args = parse_args()

    # 1. 初始化推理环境
    context.set_context(
        mode=context.GRAPH_MODE,  # 推理用GRAPH_MODE更高效
        device_target=args.device_target,
        device_id=args.device_id,
        enable_graph_kernel=False  # 推理关闭图核优化,避免兼容性问题
    )
    print(f"推理环境初始化完成:{args.device_target} (ID={args.device_id})")

    # 2. 构建模型(参数必须与训练时完全一致!)
    model = DPOTNet(
        img_size=tuple(args.img_size),
        patch_size=args.patch_size,
        mixing_type='slim_afno',
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        in_timesteps=1,
        out_timesteps=1,
        n_blocks=4,
        embed_dim=1024,
        out_layer_dim=128,
        depth=8,
        modes=32,
        mlp_ratio=2.0,
        n_cls=12,
        normalize=args.normalize,
        act='gelu',
        time_agg='exp_mlp',
        spatial_dims=3,
        slim_rank=8,
        slim_num_bands=2,
        slim_use_lora=False
    )

    # 3. 加载训练权重
    if not os.path.exists(args.ckpt_path):
        raise FileNotFoundError(f"权重文件不存在:{args.ckpt_path}")
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(model, param_dict)
    print(f"成功加载权重:{args.ckpt_path}")

    # 4. 设置模型为推理模式(关闭梯度、批量归一化/ dropout固定)
    model.set_train(False)

    # 5. 数据预处理
    input_tensor = preprocess_data(args.input_data_path, args.img_size, args.normalize)
    print(f"数据预处理完成,输入Tensor维度:{input_tensor.shape}")

    # 6. 模型推理
    # DPOTNet的forward返回 (预测结果, 分类预测)
    pred_out, cls_pred = model(input_tensor)
    print(f"推理完成,预测结果维度:{pred_out.shape},分类预测维度:{cls_pred.shape}")

    # 7. 分类结果处理(可选,如需分类概率)
    softmax = Softmax(axis=-1)
    cls_prob = softmax(cls_pred)
    cls_label = mnp.argmax(cls_prob, axis=-1).asnumpy()[0]  # 取batch中第一个样本的分类标签
    print(f"分类预测标签:{cls_label},分类概率:{cls_prob.asnumpy()[0][cls_label]:.4f}")

    # 8. 结果后处理与保存
    pred_result = postprocess_data(pred_out, args.normalize)
    np.save(args.output_data_path, pred_result)
    print(f"推理结果已保存至:{args.output_data_path}")
    print(f"最终结果维度:{pred_result.shape}")

if __name__ == "__main__":
    main()