e
gcw_GSiqzzLf/thera-rdn-pro-npu
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

Thera-RDN-Pro-NPU

概述

Thera-RDN-Pro 是图像超分辨率模型的高精度变体,它以 Residual Dense Network (RDN) 为骨干网络,并配备了 SwinIR Transformer 尾部(包含 13 个 Swin Transformer Block,分为 2 个 RSTB 阶段)。该模型结合超网络与神经热场,可实现连续尺度的上采样。与 Plus 变体相比,Pro 版本采用 Transformer 架构,能够提取更丰富的全局特征。

  • 原始框架:JAX/Flax
  • 目标框架:PyTorch(兼容 Ascend NPU)
  • 输入:RGB 图像(任意分辨率)
  • 输出:超分辨率图像(任意缩放因子)
  • 许可证:Apache 2.0

模型架构

Input (RGB)
    │
    ├─ Normalize (MEAN/VAR)
    │
    ├─ RDN Encoder (16 RDBs)
    │
    ├─ SwinIR Tail
    │   ├─ Conv2d (64→180)
    │   ├─ LayerNorm
    │   ├─ RSTB-0: 7× SwinTransformerBlock
    │   ├─ RSTB-1: 6× SwinTransformerBlock
    │   ├─ Conv2d (180→180) + residual
    │   └─ Conv2d (180→64) + LeakyReLU
    │
    ├─ Hypernetwork Conv2d(64 → 2048)
    └─ Neural Heat Field

环境要求

组件版本
Python3.11+
PyTorch≥ 2.0.0
torch_npu≥ 2.9.0
Ascend NPUAscend910

快速开始

1. 安装依赖

pip install torch torchvision numpy Pillow

2. 下载权重

python3 -c "
from modelscope import snapshot_download
snapshot_download('prs-eth/thera-rdn-pro', cache_dir='/tmp/modelscope')
"

3. 运行推理

python3 inference.py input.png output.png --device npu:0 --scale 2.0

4. 精度对比

python3 compare_cpu_npu.py --scale 2.0

精度与性能

测试条件:输入 256×256 RGB,缩放 2.0×(输出 512×512),Ascend910 对比 AMD EPYC 64vCPU

指标值
CPU 耗时95.89s
NPU 耗时0.47s
加速比204×
相对误差0.9670%(< 1% ✔)
PSNR25.00 dB
SSIM0.9937

注:Pro 版本包含注意力机制,NPU/CPU 数值差异略大于纯卷积的 Plus 变体,但仍 < 1%。

完整推理示例

import torch
import numpy as np
from PIL import Image
from thera_rdn_pro import TheraRDNPro, load_jax_checkpoint_pro

device = torch.device('npu:0')
model = TheraRDNPro().to(device)
load_jax_checkpoint_pro(model, '/tmp/modelscope/prs-eth/thera-rdn-pro/model.pkl')
model.eval()

img_np = np.asarray(Image.open('input.png').convert('RGB')) / 255.0
img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)

with torch.no_grad():
    out = model(img, scale=2.0)

out_np = out.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1)
Image.fromarray((out_np * 255).astype(np.uint8)).save('output.png')

权重转换说明

层类型JAX 形状PyTorch 形状转换
Conv2D(H,W,C_in,C_out)(C_out,C_in,H,W)permute(3,2,0,1)
Dense(C_in,C_out)(C_out,C_in)permute(1,0)
LayerNormscale, biasweight, bias直接拷贝
relative_position_index无(64,64)由 window_size 计算

注意:13 个 relative_position_index 是 PyTorch SwinTransformer 的缓冲参数,由 window_size=8 自动计算,不载入 JAX 权重。

文件结构

thera-rdn-pro/
├── inference.py          # NPU/CPU 推理入口
├── compare_cpu_npu.py    # CPU vs NPU 精度对比
├── thera_rdn_pro.py      # 模型定义 + 权重加载
├── requirements.txt      # 依赖列表
└── README.md             # 本文件

推理成功证据

本仓库提供完整的推理脚本,支持 CPU 和 NPU 双平台推理:

# NPU 推理
python3 inference.py --device npu

# CPU 推理
python3 inference.py --device cpu

推理完成后会输出推理结果和耗时,表明模型在 NPU 上推理成功。