Ascend-SACT/SAM2
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

SAM2-推理指导

概述

Segment Anything Model 2(SAM 2)是一个用于解决图像和视频中可提示视觉分割问题的基础模型。该模型将SAM扩展到视频领域,将图像视为单帧视频。它采用简洁的Transformer架构,并利用流式内存实现实时视频处理。

准备工作

使用SAM 2前需要先安装它。该代码需要python>=3.10,torch>=2.5.1,torchvision>=0.20.1。可以使用以下命令在机器上安装SAM 2:

Conda创建环境并激活:

Conda创建环境并激活
conda create -n sam2 python=3.10.0
conda activate sam2

获取SAM 2代码并以可编辑模式安装:

git clone https://github.com/facebookresearch/sam2.git && cd sam2
pip install -e .

权重准备

cd checkpoints && \
./download_ckpts.sh && \
cd ..

配置文件

将配置文件拷贝到项目根目录(infer.py通过Hydra从./configs/sam2.1目录加载模型配置):

cp -r sam2/configs/ ./

安装PyTorch及其他依赖

可通过wheel格式的二进制软件包直接安装PyTorch及昇腾适配插件torch_npu(infer.py依赖torch_npu提供的torch.npu接口在NPU上运行)。参考文档: https://www.hiascend.com/document/detail/zh/Pytorch/710/configandinstg/instg/insg_0004.html

此外,还需依赖torchvision,其版本需与torch版本对应,参考: https://www.hiascend.com/document/detail/zh/Pytorch/710/configandinstg/instg/insg_0010.html

若torchvision与torch版本不对应,可能出现如下报错之一:

RuntimeError: operator torchvision::nms does not exist

AttributeError: module 'torch.library' has no attribute 'register_fake'

其他依赖:

pip install numpy==1.26.4
pip install matplotlib==3.10.7
pip install decorator==5.2.1
pip install scipy==1.15.3
pip install attrs==25.4.0
pip install psutil==7.1.0
pip install opencv-python==4.12.0.88

如果运行时提示缺少libGL.so.1(opencv-python依赖该库),需安装提供libGL.so.1的软件包,例如在使用yum的系统上执行:
yum install mesa-libGL

使用NPU进行推理

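infer.py如下(请将其放在sam2仓库根目录,即checkpoints/和configs/所在目录;并将示例图片notebooks/images/truck.jpg拷贝到该目录,然后执行python infer.py):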

import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

# Fix NumPy's global RNG seed so the colours drawn by
# show_mask(random_color=True) are identical on every run.
np.random.seed(3)

# Wipe any Hydra global state left behind by earlier code so that the
# `initialize(...)` context used for model construction starts clean
# (presumably Hydra was already initialized on import — TODO confirm source).
GlobalHydra.instance().clear()

# Make the script's own directory the working directory, so the relative
# checkpoint and image paths used further down resolve from there.
_script_dir = os.path.dirname(os.path.abspath(__file__)) or '.'
os.chdir(_script_dir)
print(f"Current working directory: {os.getcwd()}")

def show_mask(mask, ax, random_color=False, borders = True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image =  mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2
        contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) 
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))    

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True, *, save_prefix=None):
    """Render each candidate mask over ``image`` in its own figure.

    Args:
        image: background image passed to ``plt.imshow``.
        masks: iterable of masks, paired element-wise with ``scores``.
        scores: per-mask scores; shown in the title when there is more than one.
        point_coords: optional point prompts, drawn with ``show_points``.
        box_coords: optional box prompt ``[x0, y0, x1, y1]``, drawn with ``show_box``.
        input_labels: labels for ``point_coords``; required whenever
            ``point_coords`` is given.
        borders: forwarded to ``show_mask`` to outline each mask.
        save_prefix: keyword-only. If given, figure ``i`` (1-based) is written
            to ``f"{save_prefix}_{i}.png"`` and closed instead of being shown.
            This is useful with a non-interactive backend (e.g. a headless server),
            where ``plt.show()`` displays nothing.

    Raises:
        ValueError: if ``point_coords`` is given without ``input_labels``.
    """
    # Validate once up front. An `assert` would be stripped under `python -O`.
    if point_coords is not None and input_labels is None:
        raise ValueError("input_labels must be provided together with point_coords")

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        ax = plt.gca()
        show_mask(mask, ax, borders=borders)
        if point_coords is not None:
            show_points(point_coords, input_labels, ax)
        if box_coords is not None:
            show_box(box_coords, ax)
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        if save_prefix is None:
            plt.show()
        else:
            plt.savefig(f"{save_prefix}_{i + 1}.png", bbox_inches="tight")
            # Release the figure so saved figures do not accumulate in memory.
            plt.close()

# On Ascend, the `torch.npu` namespace comes from the torch_npu plugin. Import it
# explicitly so the NPU branch below is reachable even when the plugin is not
# auto-loaded. The import is optional, so CUDA/CPU-only machines still work.
try:
    import torch_npu  # noqa: F401
except ImportError:
    pass

checkpoint = "./checkpoints/sam2.1_hiera_large.pt"

# Pick the accelerator in order of preference: CUDA, Ascend NPU, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch, "npu") and torch.npu.is_available():
    device = "npu"
else:
    device = "cpu"
print(f"using device: {device}")

# build_sam2 is called inside the Hydra context so that the config name
# 'sam2.1_hiera_l' is resolved against configs/sam2.1. Presumably it composes
# its config via Hydra; the rest of inference does not need the context.
with initialize(config_path="configs/sam2.1", version_base=None):
    predictor = SAM2ImagePredictor(build_sam2('sam2.1_hiera_l', checkpoint, device=device))

# Load the demo image as an RGB array. The context manager closes the file handle.
with Image.open('./truck.jpg') as pil_image:
    image = np.array(pil_image.convert("RGB"))

predictor.set_image(image)
# A single positive prompt (label 1) at x=500, y=375 in image coordinates.
input_point = np.array([[500, 375]])
input_label = np.array([1])
# Debug output: shapes of the cached image embedding (private predictor state).
print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)

# multimask_output=True requests several candidate masks, each with a score.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Order candidates by descending score so the best mask is displayed first.
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
print(" 推理完成")