Segment Anything Model 2(SAM 2)是一个用于解决图像和视频中可提示视觉分割问题的基础模型。该模型将SAM扩展到视频领域,将图像视为单帧视频。它采用简洁的Transformer架构,并利用流式内存实现实时视频处理。
使用SAM 2前需要先安装它。该代码需要python>=3.10,torch>=2.5.1,torchvision>=0.20.1。可以使用以下命令在机器上安装SAM 2:
使用Conda创建并激活环境:
conda create -n sam2 python=3.10.0
conda activate sam2
克隆SAM 2仓库:
git clone https://github.com/facebookresearch/sam2.git && cd sam2
以可编辑模式安装SAM 2:
pip install -e .
下载模型权重:
cd checkpoints && \
./download_ckpts.sh && \
cd ..
将配置文件拷贝到项目根目录:
cp -r sam2/configs/ ./
在昇腾(Ascend)NPU环境中,PyTorch及torch_npu插件可通过wheel格式的二进制软件包直接安装。参考文档: https://www.hiascend.com/document/detail/zh/Pytorch/710/configandinstg/instg/insg_0004.html
此外,还需依赖torchvision,其版本需与torch版本对应,参考: https://www.hiascend.com/document/detail/zh/Pytorch/710/configandinstg/instg/insg_0010.html
若版本不对应,将会出现如下报错:
RuntimeError: operator torchvision::nms does not exist
AttributeError: module 'torch.library' has no attribute 'register_fake'
其他依赖:
pip install numpy==1.26.4
pip install matplotlib==3.10.7
pip install decorator==5.2.1
pip install scipy==1.15.3
pip install attrs==25.4.0
pip install psutil==7.1.0
pip install opencv-python==4.12.0.88
如果系统缺少libGL.so.1,需安装提供该库的软件包(以yum为例):
yum install mesa-libGL
推理脚本infer.py如下:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
# Seed NumPy's global RNG; show_mask() draws from it when random_color=True.
np.random.seed(3)

# Reset any existing global Hydra state (presumably left by importing the
# sam2 package -- confirm) so the hydra.initialize() call further down starts
# from a clean slate.
GlobalHydra.instance().clear()

# Run from the directory containing this script so the relative paths used
# below (./checkpoints/..., ./truck.jpg) resolve regardless of the caller's cwd.
_script_path = os.path.abspath(__file__)
os.chdir(os.path.dirname(_script_path) or '.')
print(f"Current working directory: {os.getcwd()}")
def show_mask(mask, ax, random_color=False, borders=True):
    """Overlay a single binary mask on a matplotlib axis.

    Args:
        mask: Array whose last two dimensions are (h, w); non-zero pixels are
            foreground. Leading singleton dimensions such as (1, h, w) are
            accepted.
        ax: Axis to draw on (only its ``imshow`` method is used).
        random_color: If True, use a random RGB colour drawn from NumPy's
            global RNG with alpha 0.6; otherwise a fixed blue with alpha 0.6.
        borders: If True, also draw the mask's external contours in white
            (OpenCV is imported lazily only in this case).
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    # Collapse any leading singleton dims to a 2-D (h, w) uint8 image:
    # cv2.findContours only accepts an 8-bit single-channel 2-D image, so a
    # (1, h, w) mask would otherwise fail when borders=True.
    mask = np.asarray(mask).reshape(h, w).astype(np.uint8)
    # Broadcast the 0/1 mask against the RGBA colour -> (h, w, 4) overlay.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Lightly simplify the contours (epsilon is measured in pixels).
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    Args:
        coords: 2-D array whose columns 0 and 1 are the x and y coordinates.
        labels: Per-row labels for ``coords``; only rows labelled 1 or 0 are drawn.
        ax: Axis to draw on (only its ``scatter`` method is used).
        marker_size: Marker area forwarded as ``s`` to ``scatter``.
    """
    # Positive clicks first, then negative ones -- one scatter call each.
    for wanted, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
def show_box(box, ax):
    """Draw ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle."""
    corner_x, corner_y, far_x, far_y = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (corner_x, corner_y),
        far_x - corner_x,
        far_y - corner_y,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True, save_dir=None):
    """Display each mask overlaid on ``image`` in its own figure.

    Args:
        image: Background image passed to ``plt.imshow``.
        masks: Sequence of masks, one figure each (drawn via ``show_mask``).
        scores: Scores paired with ``masks``; shown in the title when there is
            more than one.
        point_coords: Optional prompt points, drawn via ``show_points``.
        box_coords: Optional prompt box, drawn via ``show_box``.
        input_labels: Labels for ``point_coords``; required whenever
            ``point_coords`` is given.
        borders: Forwarded to ``show_mask``.
        save_dir: Optional directory. When given, each figure is also written
            there as ``mask_<k>.png`` (k starting at 1), which is useful on
            headless machines where ``plt.show()`` displays nothing.

    Raises:
        ValueError: If ``point_coords`` is given without ``input_labels``.
    """
    # Validate up front: an ``assert`` is stripped under ``python -O``, and
    # failing before the loop avoids leaving a half-drawn figure behind.
    if point_coords is not None and input_labels is None:
        raise ValueError("input_labels must be provided together with point_coords")
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    for i, (mask, score) in enumerate(zip(masks, scores)):
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        if save_dir is not None:
            # Save before show(): with some backends the figure is emptied once shown.
            fig.savefig(os.path.join(save_dir, f"mask_{i + 1}.png"), bbox_inches="tight")
        plt.show()
        # Release the figure; otherwise figures accumulate when show() is a no-op
        # (non-interactive backends).
        plt.close(fig)
# Path to the SAM 2.1 Hiera-Large weights fetched by checkpoints/download_ckpts.sh.
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"

# ``torch.npu`` only exists once Ascend's torch_npu plugin is imported. Import it
# explicitly when it is installed, so the NPU check below does not depend on
# PyTorch's automatic device-backend loading. Without torch_npu this is a no-op.
try:
    import torch_npu  # noqa: F401
except ImportError:
    pass

# The Hydra context is needed while build_sam2 looks up the 'sam2.1_hiera_l'
# config (presumably configs/sam2.1/sam2.1_hiera_l.yaml -- Hydra resolves
# config_path relative to this file). Inference below runs after it exits.
with initialize(config_path="configs/sam2.1", version_base=None):
    # Prefer CUDA, then Ascend NPU, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch, "npu") and torch.npu.is_available():
        device = "npu"
    else:
        device = "cpu"
    print(f"using device: {device}")
    predictor = SAM2ImagePredictor(build_sam2('sam2.1_hiera_l', checkpoint, device=device))

# Load the image as an RGB array; the context manager closes the file handle.
with Image.open('./truck.jpg') as img:
    image = np.array(img.convert("RGB"))
predictor.set_image(image)

# A single positive (label 1) click at pixel (x=500, y=375).
input_point = np.array([[500, 375]])
input_label = np.array([1])
# Debug output: shapes of the cached image embedding (reads a private attribute).
print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)

# multimask_output=True requests several candidate masks for the one click.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
# Order the candidates from highest to lowest score.
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
print(" 推理完成")