该模型是 roberta-large 在 glue 数据集上的微调版本。 它在评估集上取得了以下结果:
from openmind import AutoTokenizer, AutoModel, is_torch_npu_available
from openmind_hub import snapshot_download
import torch
import argparse
import torch.nn.functional as F
import os
import time
# 均值池化 - 考虑注意力掩码以进行正确的平均
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] # model_output的第一个元素包含所有token嵌入
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to model",
default="zhouhui/roberta-large-sst2",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
model_path = args.model_name_or_path
if is_torch_npu_available():
device = "npu:0"
else:
device = "cpu"
#device = "cpu"
start_time = time.time()
# 我们想要获取句子嵌入的句子
sentences = ['This is an example sentence', 'Each sentence is converted']
# 从openmind_hub加载模型
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path).to(device)
#model = AutoModel.from_pretrained(model_path).to("cpu")
# 对句子进行分词
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
#encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to("cpu")
# 计算token嵌入
with torch.no_grad():
model_output = model(**encoded_input)
# 执行池化
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
# 归一化嵌入
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
end_time = time.time()
print("Sentence embeddings:")
print(sentence_embeddings)
time_taken = end_time - start_time
print(f"硬件环境:{device},推理执行时间:{time_taken}秒")
# print(f"硬件环境:cpu,推理执行时间:{time_taken}秒")
if __name__ == "__main__":
main()需要更多信息
需要更多信息
需要更多信息
训练过程中使用了以下超参数:
| 训练损失 | 轮次 | 步数 | 验证损失 | 准确率 |
|---|---|---|---|---|
| 0.3688 | 1.0 | 264 | 0.1444 | 0.9564 |
| 0.1529 | 2.0 | 528 | 0.1502 | 0.9518 |
| 0.107 | 3.0 | 792 | 0.1388 | 0.9530 |
| 0.0666 | 4.0 | 1056 | 0.1400 | 0.9644 |