该模型是在自定义合同数据集上对 xlm-roberta-base 进行微调得到的版本。其在评估集上取得了以下结果：
训练过程中使用了以下超参数：
from openmind import AutoModelForSequenceClassification,AutoTokenizer, AutoModel, is_torch_npu_available
from openmind_hub import snapshot_download
import torch
import argparse
import torch.nn.functional as F
import time
def parse_args():
    """Read the command-line options for the inference demo.

    Returns:
        argparse.Namespace: carries one attribute, ``model_name_or_path``
        (model path/identifier), defaulting to
        ``"zhouhui/employment-contract-ner-da"`` when the flag is omitted.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name_or_path",
        default="zhouhui/employment-contract-ner-da",
        type=str,
        help="Path to model",
    )
    return cli.parse_args()
def main():
    """Run one sequence-classification forward pass and report its timing.

    Loads the tokenizer and model named by ``--model_name_or_path`` in
    bfloat16, placed on ``npu:0`` when an NPU is available and on the CPU
    otherwise. It classifies a single premise/hypothesis pair and prints a
    ``{label: probability_percent}`` dict. It then prints the device and the
    wall-clock time of tokenization plus the forward pass.
    """
    args = parse_args()
    model_path = args.model_name_or_path
    device = "npu:0" if is_torch_npu_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )

    # NOTE(review): the default checkpoint name ("...-ner-...") suggests a
    # token-classification model, yet this demo loads a sequence-classification
    # head and feeds an NLI-style premise/hypothesis pair -- confirm intent.
    premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
    hypothesis = "The movie was good."

    # Time only tokenization + the forward pass. Model download/loading above
    # is excluded so the printed figure matches its "inference time" label.
    start_time = time.perf_counter()
    inputs = tokenizer(
        premise, hypothesis, truncation=True, return_tensors="pt"
    ).to(device)
    # No autograd graph is needed for a prediction.
    with torch.inference_mode():
        # Pass every tokenizer output (attention_mask etc.), not just input_ids.
        output = model(**inputs)
    # Upcast before softmax: bfloat16 has too few mantissa bits for
    # percentages rounded to one decimal place.
    probs = torch.softmax(output["logits"][0].float(), dim=-1).tolist()
    end_time = time.perf_counter()

    # Use the checkpoint's own label names instead of a hard-coded list, so
    # every class is reported (zip() would silently drop extras).
    id2label = model.config.id2label
    prediction = {
        id2label.get(idx, f"LABEL_{idx}"): round(prob * 100, 1)
        for idx, prob in enumerate(probs)
    }
    print(prediction)
    print(f"硬件环境:{device},推理执行时间:{end_time - start_time}秒")
if __name__ == "__main__":
main()| 训练损失 | 轮次 | 步数 | 验证损失 | 微平均F1值 |
|---|---|---|---|---|
| 0.8971 | 0.24 | 200 | 0.0205 | 0.0 |
| 0.0173 | 0.48 | 400 | 0.0100 | 0.2921 |
| 0.0092 | 0.73 | 600 | 0.0065 | 0.7147 |
| 0.0063 | 0.97 | 800 | 0.0046 | 0.8332 |
| 0.0047 | 1.21 | 1000 | 0.0047 | 0.8459 |
| 0.0042 | 1.45 | 1200 | 0.0039 | 0.8694 |
| 0.0037 | 1.69 | 1400 | 0.0035 | 0.8888 |
| 0.0032 | 1.93 | 1600 | 0.0035 | 0.8840 |
| 0.0025 | 2.18 | 1800 | 0.0029 | 0.8943 |
| 0.0023 | 2.42 | 2000 | 0.0024 | 0.9104 |
| 0.0023 | 2.66 | 2200 | 0.0032 | 0.8808 |
| 0.0021 | 2.9 | 2400 | 0.0022 | 0.9338 |
| 0.0018 | 3.14 | 2600 | 0.0020 | 0.9315 |
| 0.0015 | 3.39 | 2800 | 0.0026 | 0.9297 |