Thera-RDN-Pro 是图像超分辨率模型的高精度变体,它以 Residual Dense Network (RDN) 为骨干网络,并配备了 SwinIR Transformer 尾部(包含 13 个 Swin Transformer Block,分为 2 个 RSTB 阶段)。该模型结合超网络与神经热场,可实现连续尺度的上采样。与 Plus 变体相比,Pro 版本采用 Transformer 架构,能够提取更丰富的全局特征。
Input (RGB)
│
├─ Normalize (MEAN/VAR)
│
├─ RDN Encoder (16 RDBs)
│
├─ SwinIR Tail
│ ├─ Conv2d (64→180)
│ ├─ LayerNorm
│ ├─ RSTB-0: 7× SwinTransformerBlock
│ ├─ RSTB-1: 6× SwinTransformerBlock
│ ├─ Conv2d (180→180) + residual
│ └─ Conv2d (180→64) + LeakyReLU
│
├─ Hypernetwork Conv2d(64 → 2048)
└─ Neural Heat Field| 组件 | 版本 |
|---|---|
| Python | 3.11+ |
| PyTorch | ≥ 2.0.0 |
| torch_npu | ≥ 2.9.0 |
| Ascend NPU | Ascend910 |
pip install torch torchvision numpy Pillowpython3 -c "
from modelscope import snapshot_download
snapshot_download('prs-eth/thera-rdn-pro', cache_dir='/tmp/modelscope')
"python3 inference.py input.png output.png --device npu:0 --scale 2.0python3 compare_cpu_npu.py --scale 2.0测试条件:输入 256×256 RGB,缩放 2.0×(输出 512×512),Ascend910 对比 AMD EPYC 64vCPU
| 指标 | 值 |
|---|---|
| CPU 耗时 | 95.89s |
| NPU 耗时 | 0.47s |
| 加速比 | 204× |
| 相对误差 | 0.9670%(< 1% ✔) |
| PSNR | 25.00 dB |
| SSIM | 0.9937 |
注:Pro 版本包含注意力机制,NPU/CPU 数值差异略大于纯卷积的 Plus 变体,但仍 < 1%。
import torch
import numpy as np
from PIL import Image
from thera_rdn_pro import TheraRDNPro, load_jax_checkpoint_pro
device = torch.device('npu:0')
model = TheraRDNPro().to(device)
load_jax_checkpoint_pro(model, '/tmp/modelscope/prs-eth/thera-rdn-pro/model.pkl')
model.eval()
img_np = np.asarray(Image.open('input.png').convert('RGB')) / 255.0
img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)
with torch.no_grad():
out = model(img, scale=2.0)
out_np = out.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1)
Image.fromarray((out_np * 255).astype(np.uint8)).save('output.png')| 层类型 | JAX 形状 | PyTorch 形状 | 转换 |
|---|---|---|---|
| Conv2D | (H,W,C_in,C_out) | (C_out,C_in,H,W) | permute(3,2,0,1) |
| Dense | (C_in,C_out) | (C_out,C_in) | permute(1,0) |
| LayerNorm | scale, bias | weight, bias | 直接拷贝 |
| relative_position_index | 无 | (64,64) | 由 window_size 计算 |
注意:13 个 relative_position_index 是 PyTorch SwinTransformer 的缓冲参数,由 window_size=8 自动计算,不载入 JAX 权重。
thera-rdn-pro/
├── inference.py # NPU/CPU 推理入口
├── compare_cpu_npu.py # CPU vs NPU 精度对比
├── thera_rdn_pro.py # 模型定义 + 权重加载
├── requirements.txt # 依赖列表
└── README.md # 本文件本仓库提供完整的推理脚本,支持 CPU 和 NPU 双平台推理:
# NPU 推理
python3 inference.py --device npu
# CPU 推理
python3 inference.py --device cpu推理完成后会输出推理结果和耗时,表明模型在 NPU 上推理成功。