HuggingFace镜像/deit-base-distilled-patch16-224
模型介绍 | 文件和版本 | 分析
下载使用量:0

此模型是蒸馏版 ViT 模型,可通过 DeiTModel、DeiTForImageClassification 或 DeiTForImageClassificationWithTeacher 加载使用。请注意,该模型需要使用 DeiTFeatureExtractor 预处理输入数据。下面的示例使用 AutoFeatureExtractor,它会根据模型名称自动选用合适的特征提取器。


from transformers import  DeiTForImageClassificationWithTeacher
from openmind import AutoFeatureExtractor
from openmind_hub import snapshot_download
from PIL import Image
import requests
from openmind import is_torch_npu_available
import argparse

def parse_args(argv=None):
    """Parse the command-line options for this demo script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls are unchanged.

    Returns:
        An ``argparse.Namespace`` with a ``model_name_or_path`` attribute,
        which is ``None`` when the option is not supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="Path to model",
    )
    return parser.parse_args(argv)
def main():
    """Classify a sample COCO image with the distilled DeiT model.

    The model directory comes from ``--model_name_or_path``; when that option
    is absent, the ``ChongqingAscend/deit-base-distilled-patch16-224``
    snapshot is downloaded instead. Runs on an NPU when
    ``is_torch_npu_available()`` reports one, otherwise on CPU, and prints the
    predicted label.
    """
    # Local import keeps the script's top-level import block unchanged.
    import torch

    device = "npu" if is_torch_npu_available() else "cpu"
    args = parse_args()
    if args.model_name_or_path:
        model_path = args.model_name_or_path
    else:
        model_path = snapshot_download(
            "ChongqingAscend/deit-base-distilled-patch16-224",
            revision="main",
            resume_download=True,
            ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
        )

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # The context manager releases the connection. The timeout prevents an
    # indefinite hang on a stalled server. raise_for_status() stops an HTTP
    # error page from being handed to PIL as if it were an image.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # Image.open() is lazy; convert() forces the full decode while the
        # stream is still open, and normalizes the mode to RGB (presumably
        # what the 3-channel model expects).
        image = Image.open(response.raw).convert("RGB")

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
    model = DeiTForImageClassificationWithTeacher.from_pretrained(model_path).to(device)

    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    # Forward pass. Inference only, so autograd is disabled to avoid keeping
    # activations around for a backward pass that never happens.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", model.config.id2label[predicted_class_idx])

if __name__ == "__main__":
    # Entry point: run the demo only when executed as a script, not on import.
    main()