e
gcw_GSiqzzLf/thera-rdn-plus-npu
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

Thera-RDN-Plus-NPU

概述

Thera-RDN-Plus 是一款图像超分辨率模型,它以 Residual Dense Network (RDN) 作为骨干网络,并配备了 ConvNeXt 特征提取尾部(18 层)。该模型结合了超网络(Hypernetwork)与神经热场(Neural Heat Field)技术,能够实现连续尺度的图像上采样。本仓库提供了从 JAX 到 PyTorch 的权重转换 以及 Ascend NPU 推理 的完整适配方案。

  • 原始框架: JAX/Flax
  • 目标框架: PyTorch (Ascend NPU 兼容)
  • 输入: RGB 图像 (任意分辨率)
  • 输出: 超分辨率图像 (任意缩放因子)
  • 许可证: Apache 2.0

模型架构

Input (RGB)
    │
    ├─ Normalize (MEAN/VAR)
    │
    ├─ RDN Encoder
    │   ├─ Conv_0 (3→64)
    │   ├─ Conv_1 (64→64)
    │   ├─ 16× RDB (每块 8 层卷积)
    │   ├─ Conv_2 (64→64)
    │   ├─ Conv_3 (64→128)
    │   └─ GFF (128→64)
    │
    ├─ ConvNeXt Tail (18 层)
    │   ├─ 6× ConvNeXtBlock (dim=64)
    │   ├─ Projection (64→96)
    │   ├─ 7× ConvNeXtBlock (dim=96)
    │   ├─ Projection (96→128)
    │   └─ 3× ConvNeXtBlock (dim=128)
    │
    ├─ Hypernetwork Conv2d(128 → 2048)
    └─ Neural Heat Field

环境要求

组件版本
Python3.11+
PyTorch≥ 2.0.0
torch_npu≥ 2.9.0
Ascend NPUAscend910

快速开始

1. 安装依赖

pip install torch torchvision numpy Pillow

2. 下载权重

python3 -c "
from modelscope import snapshot_download
snapshot_download('prs-eth/thera-rdn-plus', cache_dir='/tmp/modelscope')
"

3. 运行推理

# NPU 推理
python3 inference.py input.png output.png --device npu:0 --scale 2.0

4. 精度对比

python3 compare_cpu_npu.py --scale 2.0

精度与性能

测试条件:输入 256×256 RGB,缩放 2.0×(输出 512×512),Ascend910 对比 AMD EPYC 64vCPU

指标值
CPU 耗时84.28s
NPU 耗时0.44s
加速比193×
相对误差0.0514%(< 1% ✔)
PSNR63.06 dB
SSIM0.999998

完整推理示例

import torch
import numpy as np
from PIL import Image
from thera_rdn_plus import TheraRDNPlus, load_jax_checkpoint_plus

device = torch.device('npu:0')
model = TheraRDNPlus().to(device)
load_jax_checkpoint_plus(model, '/tmp/modelscope/prs-eth/thera-rdn-plus/model.pkl')
model.eval()

img_np = np.asarray(Image.open('input.png').convert('RGB')) / 255.0
img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)

with torch.no_grad():
    out = model(img, scale=2.0)

out_np = out.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1)
Image.fromarray((out_np * 255).astype(np.uint8)).save('output.png')

权重转换说明

JAX→PyTorch 映射规则:

层类型JAX 形状PyTorch 形状转换
Conv2D(H,W,C_in,C_out)(C_out,C_in,H,W)permute(3,2,0,1)
Dense(C_in,C_out)(C_out,C_in)permute(1,0)
LayerNormscale, biasweight, bias直接拷贝
components(2, 512)(2, 512)无需转置

文件结构

thera-rdn-plus/
├── inference.py          # NPU/CPU 推理入口
├── compare_cpu_npu.py    # CPU vs NPU 精度对比
├── thera_rdn_plus.py     # 模型定义 + 权重加载
├── requirements.txt      # 依赖列表
└── README.md             # 本文件

推理成功证据

本仓库提供完整的推理脚本,支持 CPU 和 NPU 双平台推理:

# NPU 推理
python3 inference.py --device npu

# CPU 推理
python3 inference.py --device cpu

推理完成后会输出推理结果和耗时,表明模型在 NPU 上推理成功。