该模型是 google/vit-base-patch16-224 在 Bingsu/Human_Action_Recognition 数据集上的微调版本。
其在评估集上取得了以下结果:
视觉Transformer(ViT)是一种Transformer编码器模型(类BERT),在大量图像集合上以监督方式进行预训练,即ImageNet-21k,分辨率为224x224像素。随后,该模型在ImageNet(也称为ILSVRC2012)上进行微调,该数据集包含100万张图像和1000个类别,分辨率同样为224x224。
图像以固定大小的补丁序列(分辨率16x16)形式输入模型,并进行线性嵌入。为用于分类任务,还会在序列开头添加一个[CLS]标记。在将序列输入Transformer编码器各层之前,还会添加绝对位置嵌入。
通过预训练模型,其学习到图像的内部表示,进而可用于提取对下游任务有用的特征:例如,若有带标签的图像数据集,可在预训练编码器之上放置一个线性层来训练标准分类器。通常将线性层置于[CLS]标记之上,因为该标记的最后隐藏状态可视为整个图像的表示。
您可以将该模型用于图像分类。
以下是使用此模型将人类动作图像分类为以下类别之一的方法:
calling, clapping, cycling, dancing, drinking, eating, fighting, hugging, laughing, listening_to_music, running, sitting, sleeping, texting, using_laptop
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import numpy as np
import requests
import torch
import torch_npu
import os
import argparse
from openmind import pipeline, is_torch_npu_available
if is_torch_npu_available():
device = "npu:0"
else:
device = "cpu"
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="./")
args = parser.parse_args()
model_path = args.model_name_or_path
image = Image.open("./pexels-photo-175658.jpg")
tensor_img = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()/255.0
processor = ViTImageProcessor.from_pretrained('./')
model = ViTForImageClassification.from_pretrained(model_path, torch_dtype=torch.float16)
model = model.to(device)
inputs = processor(images=tensor_img.npu(), return_tensors="pt")
inputs = inputs.to(device)
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])需要更多信息
训练过程中使用了以下超参数:
| 训练损失 | 轮次 | 步数 | 验证损失 | 准确率 |
|---|---|---|---|---|
| 2.6396 | 0.99 | 39 | 2.0436 | 0.4425 |
| 1.4579 | 2.0 | 79 | 0.7553 | 0.7917 |
| 0.8342 | 2.99 | 118 | 0.5296 | 0.8417 |
| 0.6649 | 4.0 | 158 | 0.4978 | 0.8496 |
| 0.6137 | 4.99 | 197 | 0.4460 | 0.8595 |
| 0.5374 | 6.0 | 237 | 0.4356 | 0.8627 |
| 0.514 | 6.99 | 276 | 0.4349 | 0.8615 |
| 0.475 | 8.0 | 316 | 0.4005 | 0.8786 |
| 0.4663 | 8.99 | 355 | 0.4164 | 0.8659 |
| 0.4178 | 10.0 | 395 | 0.4128 | 0.8738 |
| 0.4226 | 10.99 | 434 | 0.4115 | 0.8690 |
| 0.3896 | 12.0 | 474 | 0.4112 | 0.875 |
| 0.3866 | 12.99 | 513 | 0.4072 | 0.8714 |
| 0.3632 | 14.0 | 553 | 0.4106 | 0.8718 |
| 0.3596 | 14.99 | 592 | 0.4043 | 0.8714 |
| 0.3421 | 16.0 | 632 | 0.4128 | 0.8675 |
| 0.344 | 16.99 | 671 | 0.4181 | 0.8643 |
| 0.3447 | 18.0 | 711 | 0.4128 | 0.8687 |
| 0.3407 | 18.99 | 750 | 0.4097 | 0.8714 |
| 0.3267 | 19.75 | 780 | 0.4097 | 0.8683 |