Ascend-SACT/SAM3
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

SAM3模型推理使用指导

作者简介

gaopengcheng,liushucheng

模型概述及场景

SAM3模型概述:

  1. SAM3(Segment Anything with Concepts)属于多模态视觉基础模型,是 SAM 系列分割模型的新一代迭代版本,核心聚焦图像与视频领域的可提示概念分割任务,同时延伸支持 3D 感知与重建能力,在医疗影像、机器人、增强现实等领域具备落地潜力。该模型突破了前代依赖点、框等视觉提示的单一实例分割局限,实现了检测、分割、跟踪三大视觉任务的统一建模,是通用视觉感知领域的重要进阶模型。
  2. SAM3 以概念提示为核心交互方式,可接收文本名词短语、图像样例或二者组合作为输入,返回匹配对象的分割掩码与唯一实例标识,完成全局概念的多实例分割与跨帧跟踪。模型采用共享骨干网络的检测器 - 跟踪器架构,通过创新的存在头机制实现识别与定位解耦,大幅提升开放词汇场景下的检测精度;同时设计检测 - 传播 - 匹配流程,解决视频跟踪中的漂移与遮挡问题。

SAM3模型源码链接:

https://github.com/facebookresearch/sam3

SAM3模型属于推理/训练场景:

本案例主要分享的是SAM3模型的推理场景适配。

准备运行环境

AI框架及版本

本案例使用的AI框架是torch_npu。

框架版本环境准备指导
Python3.11可以用conda安装;论文要求的是3.8-3.12。
torch2.7.1论文要求是2.7以上,在昇腾社区下载对应版本。
torch_npu2.7.1.post2在昇腾社区下载对应版本。
torchvision0.22.1
triton_ascend3.2.0源码编译安装

环境准备

  1. 设备支持:A2
  2. 部署卡类型信息:910B
  3. 部署方式:单卡
  4. 操作系统:ARM

镜像及组合制作及安装

由于论文要求的Pytorch版本是大于2.7的,其对应的CANN版本要求也高,所以从昇腾社区下载高版本CANN的镜像,下载命令如下:

docker pull --platform=arm64 swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.3.rc1.alpha001-910-openeuler22.03-py3.11

运行指导

依赖配置

  1. 为了避免torch版本被覆盖,首先安装torch相关依赖:
# 下载并安装PyTorch框架
wget https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl
pip3 install torch-2.7.1+cpu-cp311-cp311-manylinux_2_28_aarch64.whl

# 下载并安装torch_npu插件
wget https://gitcode.com/Ascend/pytorch/releases/download/v7.3.0-pytorch2.7.1/torch_npu-2.7.1.post2-cp311-cp311-manylinux_2_28_aarch64.whl
pip3 install torch_npu-2.7.1.post2-cp311-cp311-manylinux_2_28_aarch64.whl
  1. 安装代码仓指定的依赖,具体配置在pyproject.toml,需要先指定跳过torch相关的依赖:
cat > constraints.txt <<'EOF'
torch==2.7.1+cpu
torch_npu==2.7.1.post2
torchvision==0.22.1
EOF

如果只需要推理,那么执行

pip install -e . -c constraints.txt

如果需要训练、开发等操作,那么执行

pip install -e ".[dev,notebooks,train]" -c constraints.txt

根据自己的需求增减依赖。注意,如果推理过程中代码报cv的错误,可以尝试将opencv-python卸载,重装同版本的opencv-python-headless。

3.安装其他依赖 安装ffmpeg:

wget https://ffmpeg.org/releases/ffmpeg-4.4.4.tar.xz
tar xf ffmpeg-4.4.4.tar.xz
cd ffmpeg-4.4.4
./configure \
  --prefix=/usr/local \
  --enable-shared --disable-static \
  --disable-programs --disable-doc \
  --disable-avdevice --disable-postproc \
  --disable-network --disable-encoders \
  --disable-muxers --disable-bsfs --disable-devices \
  --disable-everything \
  --enable-avfilter \
  --enable-swscale \
  --enable-decoder=h264,hevc,mpeg4,mpeg2video,vp8,vp9,av1,mjpeg,rawvideo \
  --enable-demuxer=mov,matroska,avi,flv,mpegts,mpegps,rawvideo,h264,hevc \
  --enable-parser=h264,hevc,mpeg4video,mpegvideo,vp8,vp9,av1 \
  --enable-protocol=file \
  --disable-vaapi --disable-vdpau --disable-xlib --disable-sdl2
make -j$(nproc)
make install
ldconfig

源码编译安装decord:

git clone --recursive https://github.com/dmlc/decord.git
cd decord
mkdir build && cd build
cmake .. \
  -DCMAKE_BUILD_TYPE=Release \
  -DUSE_CUDA=0 \
  -DCMAKE_PREFIX_PATH=/usr/local
make -j$(nproc)
make install
ldconfig

python绑定:

pip install "numpy<2"
cd decord/python
python3 setup.py install

模型权重

  1. SAM3权重放在Hugging Face,需要申请访问权限。这里可以从modelscope中获取:
wget https://modelscope.cn/models/facebook/sam3/resolve/master/sam3.pt
  1. 注意点:模型加载权重需要修改源码,让模型加载本地权重,否则默认会去Hugging Face下载,会一直网络报错。修改的源码文件为sam3/model_builder.py;在文件内搜索load_from_hf = True → 改成 False;checkpoint_path = None → 改成 "sam3.pt",会有多处,都需要修改。

数据集

本案例使用COCO-2017 val数据集来做图像分割推理的性能和精度验证;使用SACo数据集来做视频分割推理。

快速适配

本案例提供了修改好的patch,应用之后就可以直接开始推理/微调训练任务了。

git apply sam3_pic_train_adapt.patch

图像分割推理适配

  1. 使用官方源码的demo图像进行推理适配,适配完的推理脚本如下:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import os
from decord import VideoReader, cpu
import matplotlib.pyplot as plt
from PIL import Image
import time

# 导入 SAM3 相关组件
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results

device = "npu" if torch.npu.is_available() else "cpu"
print(f"正在使用计算设备: {device}")

# 建议关闭 JIT 静态编译,增加算子适配的灵活性
torch.npu.set_compile_mode(jit_compile=False)

model = build_sam3_image_model().to(device)
model.eval()

processor = Sam3Processor(model)
processor.device = "npu"

# 加载测试图片
image_path = "assets/images/test_image.jpg"
if not os.path.exists(image_path):
    print(f"错误:找不到图片 {image_path},请确认路径。")
    exit()

image = Image.open(image_path)

print("开始设置图像(进行全图编码,首次运行 NPU 编译较慢,请耐心等待)...")

with torch.no_grad():
    # 设置图像
    inference_state = processor.set_image(image)

    print("正在进行文本提示分割推理...")
    # 文本提示分割
    inference_state = processor.set_text_prompt(state=inference_state, prompt="shoe")

# 可视化结果
print("推理完成,准备显示结果。")
plot_results(image, inference_state)
# plt.show()

plt.savefig("sam3_result.jpg", bbox_inches='tight', dpi=300)
print("可视化结果已保存为 sam3_result.jpg!")

推理完成后,效果图如下所示: 原demo图像: 原demo图像 demo图像推理后效果: 原demo图像 2. 使用COCO-2017 val数据集进行评测,评测性能和精度,评测脚本如下:

"""
SAM3 COCO val2017 推理脚本 (昇腾 NPU 适配版)

功能:
  1. 在昇腾 NPU 上加载 SAM3 模型
  2. 遍历 COCO val2017 全部 5000 张图片
  3. 对每张图片用 80 个 COCO 类别名做文本提示推理
  4. 输出 COCO 格式预测 JSON (segm + bbox)
  5. 自动运行 COCO mAP 评估

论文参考指标 (COCO val2017):
  - Box AP:  56.4
  - Box APo: 55.7

用法示例:
  python run_coco_npu_inference.py \
      --coco-ann /path/to/instances_val2017.json \
      --coco-img /path/to/val2017 \
      --output-json sam3_npu_coco_predictions.json \
      --eval
  仅跑100张测试图片:
  python run_coco_npu_inference.py \
    --coco-ann /path/to/annotations/instances_val2017.json \
    --coco-img /path/to/data/val2017 \
    --max-images 100 \
    --warmup 3 \
    --output-json sam3_npu_perf_test.json \
    --eval
"""

import argparse
import json
import os
import sys
import time
from collections import OrderedDict

import cv2
import numpy as np
import torch

import torch_npu
from torch_npu.contrib import transfer_to_npu

if torch.npu.is_available():
    torch.cuda.is_available = lambda: True
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("npu")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("npu")
    torch.cuda.current_device = lambda: 0
    torch.cuda.set_device = lambda device: None

torch.npu.set_compile_mode(jit_compile=False)
torch.npu.config.allow_internal_format = False


from PIL import Image
from pycocotools import mask as mask_util
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tqdm import tqdm

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor


def mask_to_box_xywh(mask: np.ndarray):
    """从二进制掩码提取 [x, y, w, h] 格式的外接矩形。"""
    ys, xs = np.where(mask > 0)
    if len(xs) == 0:
        return [0.0, 0.0, 0.0, 0.0]
    x0, x1 = float(xs.min()), float(xs.max())
    y0, y1 = float(ys.min()), float(ys.max())
    return [x0, y0, x1 - x0, y1 - y0]


def encode_mask_rle(mask_uint8: np.ndarray) -> dict:
    """将 uint8 掩码编码为 COCO RLE 格式。"""
    rle = mask_util.encode(np.asfortranarray(mask_uint8))
    rle["counts"] = rle["counts"].decode("utf-8")
    return rle


def run_coco_eval(gt_path: str, pred_path: str, iou_type: str = "segm",
                   img_ids=None):
    """运行标准 COCO mAP 评估并打印结果。"""
    print(f"\n{'=' * 60}")
    print(f"  COCO 评估  (iou_type={iou_type})")
    print(f"{'=' * 60}")

    coco_gt = COCO(gt_path)
    coco_dt = coco_gt.loadRes(pred_path)
    coco_eval = COCOeval(coco_gt, coco_dt, iouType=iou_type)
    if img_ids is not None:
        coco_eval.params.imgIds = img_ids
        print(f"  评估图片数: {len(img_ids)}")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    metric_names = [
        "AP", "AP_50", "AP_75", "AP_small", "AP_medium", "AP_large",
        "AR@1", "AR@10", "AR@100", "AR_small", "AR_medium", "AR_large",
    ]
    results = OrderedDict()
    for name, val in zip(metric_names, coco_eval.stats):
        results[name] = round(val * 100, 2)
        print(f"  {name:12s} = {results[name]:.2f}")

    return results


def parse_args():
    p = argparse.ArgumentParser(description="SAM3 COCO NPU 推理脚本")
    p.add_argument("--coco-ann", type=str, required=True,
                    help="COCO 标注文件路径,如 instances_val2017.json")
    p.add_argument("--coco-img", type=str, required=True,
                    help="COCO 图片目录路径,如 val2017/")
    p.add_argument("--output-json", type=str,
                    default="sam3_npu_coco_predictions.json",
                    help="输出预测 JSON 文件路径")
    p.add_argument("--checkpoint", type=str, default="sam3.pt",
                    help="模型权重路径 (默认: sam3.pt)")
    p.add_argument("--confidence-threshold", type=float, default=0.0,
                    help="Sam3Processor 置信度阈值 (mAP 评估建议设 0)")
    p.add_argument("--score-filter", type=float, default=0.01,
                    help="输出结果最低分数过滤 (低于此值的不写入 JSON)")
    p.add_argument("--max-images", type=int, default=-1,
                    help="最多处理图片数 (-1 表示全部)")
    p.add_argument("--resume-json", type=str, default=None,
                    help="断点续推:已有的中间结果 JSON 路径")
    p.add_argument("--eval", action="store_true",
                    help="推理完成后自动运行 COCO mAP 评估")
    p.add_argument("--eval-only", action="store_true",
                    help="仅评估已有预测文件,不做推理")
    p.add_argument("--resolution", type=int, default=1008,
                    help="SAM3 输入分辨率")
    p.add_argument("--warmup", type=int, default=3,
                    help="性能测量前的预热图片数 (跳过前 N 张的计时)")
    return p.parse_args()

def main():
    args = parse_args()

    # ---- 仅评估模式 ----
    if args.eval_only:
        if not os.path.isfile(args.output_json):
            print(f"错误: 预测文件 {args.output_json} 不存在")
            sys.exit(1)
        with open(args.output_json, "r") as f:
            pred_data = json.load(f)
        eval_img_ids = sorted(set(r["image_id"] for r in pred_data))
        run_coco_eval(args.coco_ann, args.output_json, iou_type="segm", img_ids=eval_img_ids)
        run_coco_eval(args.coco_ann, args.output_json, iou_type="bbox", img_ids=eval_img_ids)
        return

    device = "npu" if torch.npu.is_available() else "cpu"
    print(f"[初始化] 计算设备: {device}")

    # ---- 加载 COCO 数据 ----
    print("[Step 1] 加载 COCO 标注文件 ...")
    coco_gt = COCO(args.coco_ann)
    all_img_ids = sorted(coco_gt.getImgIds())
    if args.max_images > 0:
        img_ids = all_img_ids[:args.max_images]
    else:
        img_ids = all_img_ids
    print(f"  待推理图片数: {len(img_ids)} / {len(all_img_ids)}")

    categories = coco_gt.loadCats(coco_gt.getCatIds())
    class_names = [cat["name"] for cat in categories]
    name_to_cat_id = {cat["name"]: cat["id"] for cat in categories}
    print(f"  类别数: {len(class_names)}")

    # ---- 断点续推 ----
    done_img_ids = set()
    results = []
    if args.resume_json and os.path.isfile(args.resume_json):
        print(f"[断点续推] 加载已有结果: {args.resume_json}")
        with open(args.resume_json, "r") as f:
            results = json.load(f)
        done_img_ids = {r["image_id"] for r in results}
        print(f"  已完成图片数: {len(done_img_ids)}")

    remaining_ids = [i for i in img_ids if i not in done_img_ids]
    if len(remaining_ids) == 0:
        print("所有图片推理已完成,无需继续。")
    else:
        # ---- 加载模型 ----
        print("[Step 2] 加载 SAM3 模型 ...")
        model = build_sam3_image_model().to(device)
        state_dict = torch.load(args.checkpoint, map_location="cpu")
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        processor = Sam3Processor(model)
        processor.device = device

        # ---- 推理循环 ----
        print(f"[Step 3] 开始推理 ({len(remaining_ids)} 张图片, {len(class_names)} 个类别) ...")
        print(f"  预热图片数: {args.warmup} (前 {args.warmup} 张不计入性能统计)")

        total_start = time.time()
        save_interval = 500

        # 性能统计收集器(跳过 warmup)
        perf_encode_times = []
        perf_prompt_times = []
        perf_postproc_times = []
        perf_e2e_times = []

        with torch.no_grad(), torch.autocast(device_type="npu", dtype=torch.float16):
            for idx, img_id in enumerate(tqdm(remaining_ids, desc="NPU 推理中")):
                img_start = time.time()
                img_info = coco_gt.loadImgs(img_id)[0]
                img_path = os.path.join(args.coco_img, img_info["file_name"])

                try:
                    image = Image.open(img_path).convert("RGB")
                    orig_w, orig_h = image.size
                except Exception as e:
                    tqdm.write(f"  [跳过] 无法读取图片 {img_path}: {e}")
                    continue

                is_warmup = idx < args.warmup

                # ---- 阶段 1: 图像编码 ----
                torch.npu.synchronize()
                t0 = time.time()
                inference_state = processor.set_image(image)
                torch.npu.synchronize()
                t_encode = time.time() - t0

                # ---- 阶段 2: 文本提示推理 (80 类) ----
                img_results_count = 0
                img_prompt_time = 0.0
                img_postproc_time = 0.0

                for cat_name in class_names:
                    torch.npu.synchronize()
                    tp0 = time.time()
                    cat_state = processor.set_text_prompt(
                        state=inference_state, prompt=cat_name
                    )
                    torch.npu.synchronize()
                    tp1 = time.time()
                    img_prompt_time += (tp1 - tp0)

                    pred_masks = cat_state.get("masks")
                    pred_scores = cat_state.get("scores")

                    if pred_masks is None or len(pred_masks) == 0:
                        continue

                    # ---- 阶段 3: 后处理 ----
                    tpp0 = time.time()

                    if isinstance(pred_masks, torch.Tensor):
                        pred_masks = pred_masks.cpu().numpy()
                    if isinstance(pred_scores, torch.Tensor):
                        pred_scores = pred_scores.cpu().numpy()

                    cat_id = name_to_cat_id[cat_name]

                    for i in range(len(pred_scores)):
                        score = float(pred_scores[i])
                        if score < args.score_filter:
                            continue

                        mask_np = np.squeeze(pred_masks[i]).astype(np.uint8)
                        if mask_np.shape[0] != orig_h or mask_np.shape[1] != orig_w:
                            mask_np = cv2.resize(
                                mask_np, (orig_w, orig_h),
                                interpolation=cv2.INTER_NEAREST
                            )

                        rle = encode_mask_rle(mask_np)
                        bbox = mask_to_box_xywh(mask_np)

                        results.append({
                            "image_id": img_id,
                            "category_id": cat_id,
                            "bbox": [round(b, 2) for b in bbox],
                            "score": round(score, 5),
                            "segmentation": rle,
                        })
                        img_results_count += 1

                    img_postproc_time += (time.time() - tpp0)

                img_elapsed = time.time() - img_start

                # 收集性能数据(跳过 warmup)
                if not is_warmup:
                    perf_encode_times.append(t_encode)
                    perf_prompt_times.append(img_prompt_time)
                    perf_postproc_times.append(img_postproc_time)
                    perf_e2e_times.append(img_elapsed)

                if (idx + 1) % 50 == 0 or idx == 0:
                    elapsed_total = time.time() - total_start
                    avg_per_img = elapsed_total / (idx + 1)
                    eta = avg_per_img * (len(remaining_ids) - idx - 1)
                    avg_prompt = (img_prompt_time / len(class_names) * 1000)
                    tqdm.write(
                        f"  [{idx+1:5d}/{len(remaining_ids)}] "
                        f"{img_info['file_name']:30s} | "
                        f"编码: {t_encode*1000:7.1f}ms | "
                        f"80类推理: {img_prompt_time*1000:7.1f}ms | "
                        f"均值/prompt: {avg_prompt:5.1f}ms | "
                        f"后处理: {img_postproc_time*1000:6.1f}ms | "
                        f"合计: {img_elapsed*1000:7.1f}ms | "
                        f"累计: {len(results):7d} | "
                        f"ETA: {eta/60:.1f}min"
                    )

                # 定期保存中间结果
                if (idx + 1) % save_interval == 0:
                    _tmp = args.output_json + ".partial"
                    with open(_tmp, "w") as f:
                        json.dump(results, f)
                    tqdm.write(f"  [检查点] 已保存 {len(results)} 条结果到 {_tmp}")

        total_elapsed = time.time() - total_start
        print(f"\n[完成] 推理耗时: {total_elapsed:.1f}s "
              f"({total_elapsed/len(remaining_ids)*1000:.1f}ms/img)")

        # ---- 性能统计报告 ----
        if len(perf_encode_times) > 0:
            n = len(perf_encode_times)
            avg_encode = sum(perf_encode_times) / n * 1000
            avg_prompt_total = sum(perf_prompt_times) / n * 1000
            avg_prompt_single = avg_prompt_total / len(class_names)
            avg_postproc = sum(perf_postproc_times) / n * 1000
            avg_e2e = sum(perf_e2e_times) / n * 1000
            avg_e2e_single = avg_encode + avg_prompt_single

            print(f"\n{'=' * 70}")
            print(f"  性能统计 (跳过前 {args.warmup} 张预热, 统计 {n} 张)")
            print(f"{'=' * 70}")
            print(f"  [阶段拆分]")
            print(f"    图像编码 (set_image):           {avg_encode:8.1f} ms/img")
            print(f"    文本推理 (80类 set_text_prompt): {avg_prompt_total:8.1f} ms/img")
            print(f"    单次文本推理 (1 prompt):         {avg_prompt_single:8.1f} ms/prompt")
            print(f"    后处理 (mask→RLE+bbox):          {avg_postproc:8.1f} ms/img")
            print(f"    端到端合计 (80类):               {avg_e2e:8.1f} ms/img")
            print(f"")
            print(f"  [单 prompt 端到端] (编码 + 1次文本推理)")
            print(f"    NPU 实测:   {avg_e2e_single:8.1f} ms")
            print(f"    论文基线:       30.0 ms  (H200 GPU)")
            print(f"    NPU/GPU 比值:   {avg_e2e_single/30.0:8.2f}x")
            print(f"{'=' * 70}")

    # ---- 保存最终结果 ----
    print(f"[Step 4] 保存预测结果 ({len(results)} 条) → {args.output_json}")
    with open(args.output_json, "w") as f:
        json.dump(results, f)
    print("  保存完成!")

    # 清理中间文件
    _partial = args.output_json + ".partial"
    if os.path.isfile(_partial):
        os.remove(_partial)

    # ---- 评估 ----
    if args.eval and len(results) > 0:
        eval_img_ids = sorted(set(r["image_id"] for r in results))
        segm_results = run_coco_eval(args.coco_ann, args.output_json, iou_type="segm", img_ids=eval_img_ids)
        bbox_results = run_coco_eval(args.coco_ann, args.output_json, iou_type="bbox", img_ids=eval_img_ids)

        print(f"\n{'=' * 60}")
        print("  汇总 (论文参考: Box AP=56.4, Box APo=55.7)")
        print(f"{'=' * 60}")
        print(f"  Segm AP   = {segm_results.get('AP', 'N/A')}")
        print(f"  Segm AP50 = {segm_results.get('AP_50', 'N/A')}")
        print(f"  Segm AP75 = {segm_results.get('AP_75', 'N/A')}")
        print(f"  Box  AP   = {bbox_results.get('AP', 'N/A')}")
        print(f"  Box  AP50 = {bbox_results.get('AP_50', 'N/A')}")
        print(f"  Box  AP75 = {bbox_results.get('AP_75', 'N/A')}")


if __name__ == "__main__":
    main()

推理结果展示: COCO数据集评测图像分割推理精度 原demo图像 COCO数据集评测图像分割推理性能 原demo图像

视频分割推理脚本

  1. 使用官方源码的demo视频进行推理适配,适配完的推理脚本如下: *** 注:视频分割推理适配的时候遇到较多适配问题,详见下面常见问题部分。***
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image

from sam3.model_builder import build_sam3_video_predictor
from sam3.visualization_utils import (
    load_frame,
    prepare_masks_for_visualization,
    visualize_formatted_frame_output,
)


def propagate_in_video(predictor, session_id):
    outputs_per_frame = {}
    with torch.no_grad():
        for response in predictor.handle_stream_request(
            request=dict(
                type="propagate_in_video",
                session_id=session_id,
            )
        ):
            outputs_per_frame[response["frame_index"]] = response["outputs"]
    return outputs_per_frame

def read_video_frame(video_path):
    video_frames_for_vis = []
    if isinstance(video_path, str) and video_path.endswith(".mp4"):
        cap = cv2.VideoCapture(video_path)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            video_frames_for_vis.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()
    return video_frames_for_vis

def save_frames_as_mp4(frames, output_path, fps=30):
    if not frames:
        print("没有帧可以用")
        return

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    if not video_writer.isOpened():
        print(f"错误:无法创建视频文件: {output_path}")
        return

    for i, frame in enumerate(frames):
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        video_writer.write(frame_bgr)
        print(f"写入帧 {i+1}/{len(frames)}", end='\r')

    video_writer.release()
    print(f"\n视频已保存到:{output_path}")

    if os.path.exists(output_path):
        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f"文件大小:{file_size:.2f} MB")
    else:
        print("警告:输出文件未找到")

if __name__ == "__main__":
    print("正在加载 SAM3 视频预测器...")
    predictor = build_sam3_video_predictor("sam3.pt")

    if hasattr(predictor, 'model'):
        predictor.model = predictor.model.to("npu")
    if hasattr(predictor, 'device'):
        predictor.device = "npu"
    if hasattr(predictor, 'tracker'):
        predictor.tracker.device = "npu"

    video_path = "assets/videos/bedroom.mp4"
    if not os.path.exists(video_path):
        print(f"找不到视频文件:{video_path}")
        exit()

    print("正在读取视频帧...")
    video_frames_for_vis = read_video_frame(video_path)

    print("启动推理 Session (NPU 预热中)...")
    with torch.no_grad():
        response = predictor.handle_request(
            request=dict(
                type="start_session",
                resource_path=video_path,
            )
        )
        session_id = response["session_id"]

        prompt_text_str = "person"
        frame_idx = 0
        print(f"在第 {frame_idx} 帧添加文本提示: '{prompt_text_str}'...")
        response = predictor.handle_request(
            request=dict(
                type="add_prompt",
                session_id=session_id,
                frame_index=frame_idx,
                text=prompt_text_str,
            )
        )
        out = response["outputs"]

        print("正在进行视频全局传播推理 (请耐心等待)...")
        outputs_per_frame = propagate_in_video(predictor, session_id)

    print("推理结束,开始生成可视化帧...")
    outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)

    video_frame = []
    plt.ioff()

    for frame_idx in range(0, len(outputs_per_frame)):
        img = visualize_formatted_frame_output(
            frame_idx,
            video_frames_for_vis,
            outputs_list=[outputs_per_frame],
            titles=["SAM 3 Dense Tracking outputs"],
            figsize=(6, 4),
        )
        video_frame.append(img)
        plt.close("all")  

    print("正在保存输出视频...")
    save_frames_as_mp4(video_frame, "sam3_tracking_output.mp4", fps=30)

视频分割推理结果展示: 原demo图像 原demo图像

图像分割训练适配

  1. 注意图像分割训练需要下载完整的依赖,可以回看之前依赖安装部分。
  2. 图像分割训练适配过程修改了较多代码,因此将所有代码修改集中到了sam3_pic_train_adapt.patch,应用后图像的训练和推理任务都能够正常执行。
    git apply sam3_pic_train_adapt.patch
  3. 训练指导
    # 单卡                                                                                                     
    python sam3/train/train.py \                              
        -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
        --use-cluster 0 --num-gpus 1                                                                           
                                                                                                                
    # 多卡                                                     
    python sam3/train/train.py \                                                                               
        -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \                                       
        --use-cluster 0 --num-gpus 4  
  4. 训练精度对比

对比说明

本次对比使用NVIDIA A800进行对比训练,epoch=2,固定seed=42,数据加载关闭多线程,开启确定性计算,需要注意的是,cuda侧有部分算子例如grid_sampler_2d_backward_cuda并没有确定性实现,最终结果存在不可避免的偏差。

对比结果

npu和gpu都跑了20个epoch,结果如下:

评测结果(评测基于训练流程内置的 validation):

metricgpunpunpu-gpurelativewinner
AP0.585637120.598564210.012927092.21%NPU better
AP500.884193420.888768860.004575440.52%NPU better
AP750.638192800.664816110.026623314.17%NPU better
AP_small0.037500000.00000000-0.03750000-100.00%GPU better
AP_medium0.368411950.391290370.022878416.21%NPU better
AP_large0.626752230.637828000.011075771.77%NPU better
AR1000.672944540.689382370.016437832.44%NPU better
AR_medium0.518131570.553186840.035055276.77%NPU better
AR_large0.705562580.716484720.010922141.55%NPU better

异常值说明:AP_small NPU是0,gpu的值也很低,说明验证集中只有很少的小目标,npu并没有检测出来。

训练loss对比:

metricgpunpunpu-gpurelativewinner
train_loss54.05148356.3800372.3285544.31%GPU better
train_core_loss54.05148356.3800372.3285544.31%GPU better
train_bbox0.011454980.011595740.000140761.23%GPU better
train_giou0.110780330.112436970.001656641.50%GPU better
train_ce0.022005340.022520210.000514872.34%GPU better

常见问题

见wiki整理:https://wiki.huawei.com/domains/168301/wiki/348900/WIKI2026032410526941?title=_4a4d28e3