本项目基于 Facebook SAM2.1-Hiera-Large 模型,针对多类别语义分割任务进行了定制化训练。模型支持 4 类分割:
| 类别 ID | 颜色 (RGB) | 描述 |
|---|---|---|
| 0 | (0, 0, 127) | 深蓝色 - 背景 |
| 1 | (0, 212, 255) | 淡蓝色 |
| 2 | (127, 0, 0) | 深红色 |
| 3 | (255, 229, 0) | 黄色 |
模型架构:
训练效果:
docker pull swr.cn-southwest-2.myhuaweicloud.com/atelier/pytorch_ascend:pytorch_2.7.1-cann_8.3.rc1-py_3.11-hce_2.0.2509-aarch64-snt9b-20260329090059-baf3933mkdir -p /data/sam2-train-data
export IMAGE=swr.cn-southwest-2.myhuaweicloud.com/atelier/pytorch_ascend:pytorch_2.7.1-cann_8.3.rc1-py_3.11-hce_2.0.2509-aarch64-snt9b-20260329090059-baf3933
export NAME=sam2
docker run -u root --privileged \
--name $NAME \
--net=host \
--shm-size=16g \
--device /dev/davinci0 \
--device /dev/davinci1 \
--device /dev/davinci2 \
--device /dev/davinci3 \
--device /dev/davinci4 \
--device /dev/davinci5 \
--device /dev/davinci6 \
--device /dev/davinci7 \
--device /dev/davinci_manager \
--device /dev/devmm_svm \
--device /dev/hisi_hdc \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
-v /etc/ascend_install.info:/etc/ascend_install.info \
-v /etc/hccn.conf:/etc/hccn.conf \
-v /data/sam2-train-data:/data/sam2-train-data \
-it $IMAGE bash# 设置 Ascend 环境(容器运行可选)
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 设置线程限制(容器运行可选)
export OPENBLAS_NUM_THREADS=1
# 设置 HuggingFace 镜像 (国内)
export HF_ENDPOINT=https://hf-mirror.com数据集采用分层目录结构:
data_dir/
├── JPEGImages/
│ ├── category_1/
│ │ ├── image_001.png
│ │ ├── image_002.png
│ │ └── ...
│ ├── category_2/
│ │ └── ...
│ └── ...
└── Annotations/
├── category_1/
│ ├── image_001.png # 彩色标注
│ ├── image_002.png
│ └── ...
├── category_2/
│ └── ...
└── ...标注图片要求:
python run_train_npu_multiclass.py \
--model facebook/sam2.1-hiera-large \
--data_dir /data/sam2-train-data \
--output_dir ./output_fixed \
--epochs 30 \
--batch_size 1 \
--base_lr 1e-5 \
--vision_lr 1e-5 \
--device npu:0python run_train_npu_multiclass.py \
--checkpoint ./output_fixed/best_model.pt \
--model facebook/sam2.1-hiera-large \
--data_dir /data/sam2-train-data \
--output_dir ./output_fixed \
--epochs 60 \
--batch_size 1 \
--base_lr 1e-5 \
--vision_lr 1e-5 \
--device npu:0nohup python run_train_npu_multiclass.py \
--data_dir /data/sam2-train-data \
--output_dir ./output_fixed \
--epochs 30 \
> train.log 2>&1 &
# 查看训练日志
tail -f train.log# 8 卡分布式训练
torchrun --nproc_per_node=8 \
run_train_npu_multiclass.py \
--model facebook/sam2.1-hiera-large \
--data_dir /data/sam2-train-data \
--output_dir ./output_fixed \
--epochs 30 \
--batch_size 1 \
--base_lr 1e-5 \
--vision_lr 1e-5 \
--num_workers 0torchrun --nproc_per_node=8 \
run_train_npu_multiclass.py \
--checkpoint ./output_fixed/best_model.pt \
--model facebook/sam2.1-hiera-large \
--data_dir /data/sam2-train-data \
--output_dir ./output_fixed \
--epochs 60 \
--batch_size 1 \
--base_lr 1e-5 \
--vision_lr 1e-5 \
--num_workers 0# 创建启动脚本
cat > distributed_train.sh << 'EOF'
#!/bin/bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh
export OPENBLAS_NUM_THREADS=1
export HF_ENDPOINT=https://hf-mirror.com
torchrun --nproc_per_node=8 \
run_train_npu_multiclass.py \
--model facebook/sam2.1-hiera-large \
--data_dir /data/sam2-train-data \
--output_dir ./output_fixed \
--epochs 30 \
--batch_size 1 \
--base_lr 1e-5 \
--vision_lr 1e-5 \
--num_workers 0
EOF
chmod +x distributed_train.sh
# 后台运行分布式训练
nohup ./distributed_train.sh > distributed_train.log 2>&1 &# 使用 4 卡
torchrun --nproc_per_node=4 run_train_npu_multiclass.py ...
# 使用 2 卡
torchrun --nproc_per_node=2 run_train_npu_multiclass.py ...| 参数 | 默认值 | 说明 |
|---|---|---|
--model | facebook/sam2.1-hiera-large | 预训练模型 ID |
--data_dir | 必填 | 数据集目录 |
--output_dir | ./output | 输出目录 |
--epochs | 10 | 训练轮数 |
--batch_size | 1 | 单卡批大小 |
--base_lr | 1e-4 | 基础学习率 (分类器等) |
--vision_lr | 1e-5 | Vision Encoder 学习率 |
--checkpoint | None | 续训 checkpoint 路径 |
--device | npu:0 | 设备 (单卡训练时) |
--local_rank | 0 | 分布式训练自动设置 |
--num_workers | 4 | DataLoader workers 数量,分布式建议设为 0 |
python analyze_training.py --output_dir ./output_fixed/Training Progress Summary:
============================================================
Epoch Loss mIoU Status
------------------------------------------------------------
5 0.2990 0.9037 Saved
10 0.1486 0.9593 Saved
15 0.1363 0.9461 Saved
20 0.1054 0.9590 Saved
25 0.0941 0.9637 Saved
30 0.0912 0.9651 Saved
============================================================
Best Model:
------------------------------------------------------------
File: best_model.pt
Epoch: 30
Loss: 0.0912
mIoU: 0.9651
------------------------------------------------------------
Training Analysis:
------------------------------------------------------------
Loss improvement: 69.5% (0.2990 -> 0.0912)
mIoU improvement: 6.8% (0.9037 -> 0.9651)
Best model: Epoch 30, mIoU = 0.9651, Loss = 0.0912| 指标 | 说明 | 计算方式 |
|---|---|---|
| Loss | 总损失 | CrossEntropy + 0.5×Dice Loss |
| mIoU | 平均交并比 | 所有类别 IoU 的平均值 |
| Epoch Time | 单轮训练时间 | NPU 推理 + 数据加载时间 |
| LR | 当前学习率 | Warmup + Cosine Annealing |
python run_inference_final.py \
--checkpoint ./output_fixed/best_model.pt \
--image /path/to/image.png \
--output mask.png \
--device npu:0python run_inference_final.py \
--checkpoint ./output_fixed/best_model.pt \
--image /path/to/image.png \
--output mask.png \
--point 600,600python run_inference_final.py \
--checkpoint ./output_fixed/best_model.pt \
--image /path/to/image.png \
--output mask.png \
--box 100,100,500,500python run_inference_final.py \
--checkpoint ./output_fixed/best_model.pt \
--image /path/to/image.png \
--output mask.png \
--no_color| 参数 | 说明 |
|---|---|
--checkpoint | 训练好的模型路径 |
--image | 输入图片路径 |
--output | 输出 mask 路径 |
--device | 设备 (npu:0 或 cpu) |
--point | 分割点坐标 (x,y),默认图片中心 |
--box | 边界框 (x_min,y_min,x_max,y_max) |
--no_color | 输出灰度 mask |
--model | 预训练模型 ID |
============================================================
SAM2 NPU Inference - Multi-class (Fixed Version)
============================================================
Loading checkpoint from: ./output_fixed/best_model.pt
Checkpoint epoch: 30
Checkpoint IoU: 0.9651
Number of classes: 4
各类别占比:
深蓝色(背景): 20.7%
淡蓝色: 32.3%
深红色: 26.4%
黄色: 20.6%
Mask saved to: mask.png
============================================================output_fixed/
├── best_model.pt # 最佳模型 (mIoU 最高)
├── checkpoint_epoch_5.pt # 第 5 轮 checkpoint
├── checkpoint_epoch_10.pt # 第 10 轮 checkpoint
├── checkpoint_epoch_15.pt # 第 15 轮 checkpoint
├── checkpoint_epoch_20.pt # 第 20 轮 checkpoint
├── checkpoint_epoch_25.pt # 第 25 轮 checkpoint
├── checkpoint_epoch_30.pt # 第 30 轮 checkpoint
└── ...每个检查点包含:
epoch: 训练轮数model_state_dict: 模型权重optimizer_state_dict: 优化器状态loss: 平均损失iou: 平均 mIoUnum_classes: 类别数量 (4)num_frames: 视频帧数 (1)max_num_objects: 最大对象数 (10)base_lr: 基础学习率vision_lr: 视觉编码器(Vision Encoder)学习率total_epochs: 总训练轮数color_to_class: 颜色映射信息原因:优化器(Optimizer)状态未正确恢复或模型架构不匹配
解决:确保使用 --model facebook/sam2.1-hiera-large 参数
原因:颜色映射未固定
解决:使用修复后的训练脚本重新训练
解决:减小 batch_size 或使用梯度检查点(gradient checkpointing)
解决:设置镜像 export HF_ENDPOINT=https://hf-mirror.com
原因:环境变量未正确设置
解决:确保已执行 source /usr/local/Ascend/ascend-toolkit/set_env.sh
原因:DDP 包装后模型结构与保存时不同
解决:脚本已自动处理,使用 model.module.state_dict() 保存
原因:HCCL 通信库不支持 Float64(Double)类型,仅支持 Float32/Float16
解决:已修复,分布式训练时使用 dtype=torch.float32 创建张量(tensor)
错误示例:
RuntimeError: HCCL allreduce: Unsupported data type at::kDouble
ERR02007 DIST feature not supported修复位置:run_train_npu_multiclass.py 第 350-352 行
# 修复前(错误)
loss_tensor = torch.tensor([avg_loss], device=device) # 默认 Float64
# 修复后(正确)
loss_tensor = torch.tensor([avg_loss], dtype=torch.float32, device=device)
iou_tensor = torch.tensor([avg_iou], dtype=torch.float32, device=device)SAM2-A2/
├── run_train_npu_multiclass.py # 训练脚本 (支持分布式,已修复HCCL问题)
├── run_inference_final.py # 推理脚本
├── analyze_training.py # 训练分析脚本
├── npu_compat.py # NPU 兼容层
├── SAM2_README.md # 本文档
└── output_fixed/ # 输出目录
├── best_model.pt
└── checkpoint_epoch_*.pt| 版本 | 日期 | 更新内容 |
|---|---|---|
| v1.0 | 2026-04-21 | 初始版本,固定颜色映射,4 类分割 |
| v1.1 | 2026-04-21 | 修复续训 optimizer 恢复问题 |
| v1.2 | 2026-04-21 | 添加分布式训练支持 (torchrun) |
| v1.3 | 2026-05-14 | 修复 HCCL Float64 类型不支持问题 |
如有问题,请查看训练日志或使用 analyze_training.py 分析训练状态。