由于此模型是蒸馏版 ViT 模型,您可以将其接入 DeiTModel、DeiTForImageClassification 或 DeiTForImageClassificationWithTeacher。请注意,该模型要求使用 DeiTFeatureExtractor 来准备数据。这里我们使用 AutoFeatureExtractor,它会根据模型名称自动选用合适的特征提取器。
from transformers import DeiTForImageClassificationWithTeacher
from openmind import AutoFeatureExtractor
from openmind_hub import snapshot_download
from PIL import Image
import requests
from openmind import is_torch_npu_available
import argparse
def parse_args(argv=None):
    """Parse the command-line options for the DeiT inference demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave as before.

    Returns:
        argparse.Namespace with a single attribute ``model_name_or_path``
        (``str`` or ``None`` when the option is not supplied).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to model",
        default=None,
    )
    return parser.parse_args(argv)
def main():
    """Classify a sample COCO image with a distilled DeiT model and print the label.

    The model directory is taken from ``--model_name_or_path`` when given;
    otherwise a snapshot of ``ChongqingAscend/deit-base-distilled-patch16-224``
    is downloaded. Inference runs on an NPU when one is available, else on CPU.

    Raises:
        requests.HTTPError: if the sample image cannot be fetched successfully.
    """
    # Imported locally so the file's top-level import block stays untouched.
    import torch

    args = parse_args()
    device = "npu" if is_torch_npu_available() else "cpu"

    if args.model_name_or_path:
        model_path = args.model_name_or_path
    else:
        # The ignore patterns presumably skip non-PyTorch weight formats so
        # only the files needed here are downloaded.
        model_path = snapshot_download(
            "ChongqingAscend/deit-base-distilled-patch16-224",
            revision="main",
            resume_download=True,
            ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
        )

    url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    # Bound the request with a timeout and fail loudly on HTTP errors instead
    # of handing an error page to PIL; the context manager releases the
    # connection once the image has been decoded.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # Let urllib3 undo any transfer compression before PIL reads the bytes.
        response.raw.decode_content = True
        image = Image.open(response.raw)
        # Image.open is lazy; force the pixel data to be read while the
        # underlying stream is still open.
        image.load()

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
    model = DeiTForImageClassificationWithTeacher.from_pretrained(model_path).to(device)
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    # forward pass; gradients are not needed for inference
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", model.config.id2label[predicted_class_idx])
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()