Ascend-SACT/WGANModel

WGAN Model Migration and Adaptation Guide

1. Model Overview

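WGAN replaces the Jensen-Shannon (JS) divergence used by the original GAN with the Wasserstein distance, and recasts the discriminator as a "critic" that outputs an unbounded real-valued score instead of a probability, so that it can estimate the distance between the generated distribution and the real one. Its main advantages are that it mitigates training instability and mode collapse, and that the critic's loss correlates with generation quality, so it can be used to monitor training. The theory requires the critic to be 1-Lipschitz, i.e. its gradients must be bounded. Instead of the weight clipping used by the original WGAN, this implementation enforces that constraint with a gradient penalty, which usually gives better generation results.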
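To see why the Wasserstein distance gives a more useful training signal than the JS divergence, the pure-Python sketch below compares the two on small 1-D samples whose supports do not overlap. It is an illustration only, not DeepChem code, and the helper names (`wasserstein_1d`, `js_divergence`) are hypothetical. For equal-size 1-D samples with uniform weights, W1 reduces to the mean absolute difference of the sorted values.

```python
import math
from collections import Counter


def wasserstein_1d(xs, ys):
    """W1 distance between two equal-size 1-D samples with uniform weights.

    In 1-D the optimal transport plan matches the i-th smallest points, so W1
    is the mean absolute difference of the sorted samples.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def js_divergence(xs, ys):
    """Jensen-Shannon divergence (natural log) between two empirical distributions."""
    p, q = Counter(xs), Counter(ys)
    js = 0.0
    for atom in set(p) | set(q):
        pa, qa = p[atom] / len(xs), q[atom] / len(ys)
        m = 0.5 * (pa + qa)
        if pa > 0:
            js += 0.5 * pa * math.log(pa / m)
        if qa > 0:
            js += 0.5 * qa * math.log(qa / m)
    return js


real = [0.0, 1.0, 2.0, 3.0]
for shift in (8.0, 4.0, 0.5):
    fake = [x + shift for x in real]  # "generated" samples, disjoint from the real ones
    # W1 shrinks with the shift (8.00, 4.00, 0.50); JS stays at ln 2 (0.6931)
    print(f"shift={shift}: W1={wasserstein_1d(real, fake):.2f}, JS={js_divergence(real, fake):.4f}")
```

As the generated sample moves toward the real one, W1 falls in proportion to the shift, while the JS divergence stays at ln 2 as long as the supports are disjoint, so it tells the generator nothing about which direction to move.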

2. Preparing the Runtime Environment

2.1 Software Environment

Component    Version
Python       3.10.19
PyTorch      2.1.0
torch_npu    2.1.0.post13
CANN         8.1.RC1
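Before running anything, it can help to confirm that the installed packages match the table. The sketch below uses only the standard library (`importlib.metadata`); the `check_versions` helper and the `EXPECTED` mapping are hypothetical. CANN is a system toolkit rather than a pip package, so it is not checked here, and local version suffixes such as `+cpu` are stripped before comparing.

```python
import platform
from importlib import metadata

# Versions from the software table ("python" is read from the running interpreter).
EXPECTED = {
    "python": "3.10.19",
    "torch": "2.1.0",
    "torch_npu": "2.1.0.post13",
}


def check_versions(expected):
    """Return {name: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, want in expected.items():
        if name == "python":
            have = platform.python_version()
        else:
            try:
                # Drop local version tags such as "+cpu" before comparing.
                have = metadata.version(name).split("+")[0]
            except metadata.PackageNotFoundError:
                have = None
        if have != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    print(check_versions(EXPECTED) or "all versions match")
```

Run it inside the conda environment created in section 3.1; any entry in the returned dict shows the expected and the installed version (None if the package is missing).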

2.2 Hardware Environment

Device model     NPU configuration
Atlas 800T A2    Single card

2.3 Preparing the Image

Image environment    Image address
Public network       swr.cn-southwest-2.myhuaweicloud.com/atelier/pytorch_2_1_ascend:pytorch_2.1.0-cann_8.1.rc1-py_3.10-euler_2.10.11-aarch64-snt9b-20250603154214-4e60e43

2.4 Starting the Container

IMAGE_ID=swr.cn-southwest-2.myhuaweicloud.com/atelier/pytorch_2_1_ascend:pytorch_2.1.0-cann_8.1.rc1-py_3.10-euler_2.10.11-aarch64-snt9b-20250603154214-4e60e43
CONTAINER_NAME=wgan
docker run -u root --privileged \
 --name ${CONTAINER_NAME} \
 --device /dev/davinci0 \
 --device /dev/davinci_manager \
 --device /dev/devmm_svm \
 --device /dev/hisi_hdc \
 -v /usr/local/dcmi:/usr/local/dcmi \
 -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
 -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
 -v /etc/ascend_install.info:/etc/ascend_install.info \
 -itd ${IMAGE_ID} /bin/bash
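Once the container is up, one quick way to confirm that the device nodes and driver paths passed via `--device` and `-v` are actually visible inside it is to check for them from Python. This is a hypothetical helper, not part of any Ascend tool; `npu-smi info` (mounted by the command above) remains the direct check of the NPU itself.

```python
import os

# Paths passed to `docker run` via --device and -v in the command above.
REQUIRED_PATHS = [
    "/dev/davinci0",
    "/dev/davinci_manager",
    "/dev/devmm_svm",
    "/dev/hisi_hdc",
    "/usr/local/dcmi",
    "/usr/local/bin/npu-smi",
    "/usr/local/Ascend/driver",
    "/etc/ascend_install.info",
]


def missing_paths(paths):
    """Return the subset of `paths` that does not exist in this filesystem."""
    return [p for p in paths if not os.path.exists(p)]


if __name__ == "__main__":
    missing = missing_paths(REQUIRED_PATHS)
    print("missing:", missing if missing else "none")
```

Run it inside the container; "missing: none" means every device and mount from the command is present.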

3. Running Guide

3.1 Creating the Environment

docker exec -it ${CONTAINER_NAME} bash
conda create -n wgan --clone PyTorch-2.1.0
conda activate wgan

3.2 Migration and Adaptation

Install the deepchem-ascend binary package directly. The WGAN model has been rebuilt from the adapted code into a binary package and published on PyPI.

pip install deepchem-ascend==0.0.5
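After installation, a minimal import check can confirm that the package exposes the class used by the test example. The sketch assumes that deepchem-ascend installs under the import name `deepchem`, as the test code's `import deepchem as dc` implies; the `resolve` helper is hypothetical.

```python
import importlib


def resolve(dotted_path):
    """Import `dotted_path` (e.g. "pkg.module.Class") and return the object, or None.

    Tries the longest importable module prefix first, then walks the
    remaining parts as attributes.
    """
    parts = dotted_path.split(".")
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue  # not a module; try a shorter prefix
        try:
            for attr in parts[i:]:
                obj = getattr(obj, attr)
        except AttributeError:
            return None
        return obj
    return None


if __name__ == "__main__":
    wgan_cls = resolve("deepchem.models.torch_models.WGANModel")
    print("WGANModel available:", wgan_cls is not None)
```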

3.3 Test Example

import deepchem as dc
import numpy as np

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # helper classes that depend on torch; they need to be inside the try/except block
    class Generator(nn.Module):
        """A simple generator for testing."""

        def __init__(self, noise_input_shape, conditional_input_shape):
            super(Generator, self).__init__()
            self.noise_input_shape = noise_input_shape
            self.conditional_input_shape = conditional_input_shape

            self.noise_dim = noise_input_shape[1:]
            self.conditional_dim = conditional_input_shape[1:]

            input_dim = sum(self.noise_dim) + sum(self.conditional_dim)
            self.output = nn.Linear(input_dim, 1)

        def forward(self, input):
            noise_input, conditional_input = input

            inputs = torch.cat((noise_input, conditional_input), dim=1)
            output = self.output(inputs)
            return output

    class Discriminator_WGAN(nn.Module):
        """A simple discriminator for testing."""

        def __init__(self, data_input_shape, conditional_input_shape):
            super(Discriminator_WGAN, self).__init__()
            self.data_input_shape = data_input_shape
            self.conditional_input_shape = conditional_input_shape

            data_dim = data_input_shape[1:]  # Extracting the actual data dimension
            conditional_dim = conditional_input_shape[1:]  # Extracting the actual conditional dimension
            input_dim = sum(data_dim) + sum(conditional_dim)

            # Define the dense layers
            self.dense1 = nn.Linear(input_dim, 10)
            self.dense2 = nn.Linear(10, 1)

        def forward(self, input):
            data_input, conditional_input = input
            discrim_in = torch.cat((data_input, conditional_input), dim=1)
            output = F.relu(self.dense1(discrim_in))
            output = self.dense2(output)
            return output
except ModuleNotFoundError:
    print("PyTorch not found: WGANModel is not supported in this environment")

def generate_batch(batch_size):
    """Draw training data from a Gaussian distribution, where the mean is a conditional input."""
    means = 10 * np.random.random([batch_size, 1])
    values = np.random.normal(means, scale=2.0)
    return means, values

def generate_data(gan, batches, batch_size):
    for _ in range(batches):
        means, values = generate_batch(batch_size)
        batch = {gan.data_inputs[0]: values, gan.conditional_inputs[0]: means}
        yield batch

class ExampleWGAN(dc.models.torch_models.WGANModel):

    def get_noise_input_shape(self):
        return (100, 2)

    def get_data_input_shapes(self):
        return [(100, 1)]

    def get_conditional_input_shapes(self):
        return [(100, 1)]

    def create_generator(self):
        noise_dim = self.get_noise_input_shape()
        conditional_dim = self.get_conditional_input_shapes()[0]

        return nn.Sequential(Generator(noise_dim, conditional_dim))

    def create_discriminator(self):
        data_input_shape = self.get_data_input_shapes()[0]
        conditional_input_shape = self.get_conditional_input_shapes()[0]

        return nn.Sequential(
            Discriminator_WGAN(data_input_shape, conditional_input_shape))


gan = ExampleWGAN(learning_rate=0.01, gradient_penalty=0.1)
gan.fit_gan(generate_data(gan, 1000, 100), generator_steps=0.1)
device = gan.device
print(f"Model device: {device}")

means = 10 * np.random.random([1000, 1])
values = gan.predict_gan_generator(conditional_inputs=[means])
print(values.shape)
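The call `fit_gan(..., generator_steps=0.1)` controls how often the generator is updated relative to the critic. According to DeepChem's documentation for `fit_gan`, `generator_steps` is the number of generator training steps per batch, so 0.1 means one generator step for every ten batches while the discriminator (critic) trains once per batch, in line with the usual WGAN practice of training the critic more often. The sketch below reproduces only that ratio; `generator_steps_per_batch` is a hypothetical helper, not DeepChem's training loop, and it uses exact fractions to keep the counts clean.

```python
from fractions import Fraction


def generator_steps_per_batch(n_batches, generator_steps):
    """Number of generator updates after each batch, assuming the critic
    trains once per batch and the generator banks `generator_steps` credit
    per batch, taking one step per whole unit of credit."""
    per_batch = Fraction(str(generator_steps))  # str() so 0.1 becomes exactly 1/10
    credit = Fraction(0)
    schedule = []
    for _ in range(n_batches):
        credit += per_batch
        steps = 0
        while credit >= 1:
            steps += 1
            credit -= 1
        schedule.append(steps)
    return schedule


schedule = generator_steps_per_batch(1000, 0.1)  # same settings as the fit_gan call
print(sum(schedule))   # 100 generator updates for 1000 critic updates
print(schedule[:10])   # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
```

Values above 1.0 (for example 2.0) would instead give the generator several steps per batch.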

3.4 Running the Test Code

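Copy the test code above into test_wgan.py and run it:

python test_wgan.py

The script does all of its work at module level and defines no test_* functions, so running it as pytest test_wgan.py only executes it during test collection, after which pytest reports that no tests ran.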
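pytest only reports results for functions whose names start with `test_`, so checks meant for pytest have to live in such functions. A minimal sketch for the data source used in the test script follows; the file name `test_wgan_data.py` and the test names are hypothetical, and training the GAN itself is left out so these tests run without an NPU.

```python
# test_wgan_data.py (hypothetical): pytest collects functions named test_*.
import numpy as np


def generate_batch(batch_size):
    """Same data source as the test script: values ~ N(mean, 2), mean in [0, 10)."""
    means = 10 * np.random.random([batch_size, 1])
    values = np.random.normal(means, scale=2.0)
    return means, values


def test_generate_batch_shapes_and_ranges():
    np.random.seed(0)
    means, values = generate_batch(100)
    assert means.shape == (100, 1)
    assert values.shape == (100, 1)
    assert (means >= 0).all() and (means < 10).all()


def test_conditional_mean_is_respected():
    np.random.seed(0)
    means, values = generate_batch(10000)
    # Residuals should have mean ~0 and std ~2 (the scale passed to np.random.normal).
    residuals = values - means
    assert abs(residuals.mean()) < 0.1
    assert abs(residuals.std() - 2.0) < 0.1
```

Run it with pytest test_wgan_data.py -q.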

Test results: [image]