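This model was pre-trained on protein and antibody sequences using a masked language modelling (MLM) objective. It was introduced in the paper "Large scale paired antibody language models" (https://arxiv.org/abs/2403.17889).
The model was fine-tuned from ProtBert-BFD on unpaired antibody sequences from the Observed Antibody Space.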
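The MLM objective can be illustrated with a small pure-Python sketch that corrupts a residue sequence and records what the model must predict. It uses the masking recipe from the original BERT paper (select about 15% of positions; replace 80% of them with [MASK], 10% with a random residue, and leave 10% unchanged). This is an assumption for illustration only: the exact masking rate and scheme used to train this model are not stated here, and `bert_style_mask` is a hypothetical helper, not part of transformers or openmind.

```python
import random

# Hypothetical vocabulary: the 20 standard amino acids only. The real tokeniser
# vocabulary also contains special tokens and extra residue symbols.
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")
MASK = "[MASK]"


def bert_style_mask(residues, mask_prob=0.15, seed=0):
    """Corrupt `residues` with the BERT 80/10/10 scheme (assumed, see lead-in).

    Returns (inputs, labels): labels[i] holds the original residue at positions
    selected for prediction and None elsewhere (no MLM loss at those positions).
    """
    rng = random.Random(seed)
    inputs = list(residues)
    labels = [None] * len(residues)
    for i, residue in enumerate(residues):
        if rng.random() >= mask_prob:
            continue  # position not selected: left as-is, no prediction target
        labels[i] = residue
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.choice(AMINO_ACIDS)  # 10%: replace with a random residue
        # remaining 10%: keep the original residue
    return inputs, labels


seq = list("EVVMTQSPASLSVSPGERATLSC")
masked, labels = bert_style_mask(seq)
print(" ".join(masked))
```

During training the model sees `masked` and is scored only on the positions where `labels` is not None.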
The model and tokeniser can be loaded with the transformers library:
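```python
import argparse

import torch
from openmind import is_torch_npu_available
from transformers import BertModel, BertTokenizer

# Use the first Ascend NPU when one is available, otherwise fall back to the CPU
if is_torch_npu_available():
    import torch_npu  # noqa: F401  registers the "npu" device with PyTorch
    device = "npu:0"
else:
    device = "cpu"

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="./")
args = parser.parse_args()
model_path = args.model_name_or_path

# Load the tokeniser from the same path as the model; do_lower_case=False keeps
# the upper-case amino-acid letters used in the vocabulary
tokeniser = BertTokenizer.from_pretrained(model_path, do_lower_case=False)
# The pooling head is not needed for residue embeddings
model = BertModel.from_pretrained(model_path, add_pooling_layer=False)
```

The tokeniser is used to prepare batched input: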
```python
# Single-chain sequences
sequences = [
    "EVVMTQSPASLSVSPGERATLSCRARASLGISTDLAWYQQRPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDSAVYYCQQYSNWPLTFGGGTKVEIK",
    "ALTQPASVSGSPGQSITISCTGTSSDVGGYNYVSWYQQHPGKAPKLMIYDVSKRPSGVSNRFSGSKSGNTASLTISGLQSEDEADYYCNSLTSISTWVFGGGTKLTVL",
]
# The tokeniser expects space-separated residues, e.g. ["E V V M...", "A L T Q..."]
sequences = [" ".join(sequence) for sequence in sequences]
tokens = tokeniser(
    sequences,
    add_special_tokens=True,
    padding=True,  # pad every sequence to the longest one in the batch
    return_tensors="pt",
    return_special_tokens_mask=True,
)
model = model.to(device)
```

Note that the tokeniser adds a [CLS] token at the start of each sequence and a [SEP] token at the end, and pads shorter sequences with [PAD] tokens. For example, a batch containing the sequences E V V M and A L is tokenised as [CLS] E V V M [SEP] and [CLS] A L [SEP] [PAD] [PAD].
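The token layout described above can be mimicked in a few lines of pure Python, which makes the role of each mask explicit. This is an illustrative sketch only: `toy_tokenise` is a hypothetical helper that works on token strings rather than vocabulary ids and is not how `BertTokenizer` is implemented. It builds the same `attention_mask` (1 for real tokens, 0 for padding) and `special_tokens_mask` (1 for [CLS], [SEP] and [PAD], 0 for residues) that the averaging step later in this document relies on.

```python
def toy_tokenise(sequences):
    """Hypothetical stand-in for the tokeniser's [CLS]/[SEP]/[PAD] layout."""
    spaced = [" ".join(seq) for seq in sequences]  # "EVVM" -> "E V V M"
    tokenised = [["[CLS]"] + s.split() + ["[SEP]"] for s in spaced]
    longest = max(len(t) for t in tokenised)  # pad to the longest sequence
    batch = {"tokens": [], "attention_mask": [], "special_tokens_mask": []}
    for toks in tokenised:
        n_pad = longest - len(toks)
        batch["tokens"].append(toks + ["[PAD]"] * n_pad)
        # 1 = token the model attends to ([CLS], residues, [SEP]); 0 = padding
        batch["attention_mask"].append([1] * len(toks) + [0] * n_pad)
        # 1 = special token ([CLS], [SEP], [PAD]); 0 = residue
        batch["special_tokens_mask"].append(
            [1] + [0] * (len(toks) - 2) + [1] + [1] * n_pad
        )
    return batch


batch = toy_tokenise(["EVVM", "AL"])
for toks in batch["tokens"]:
    print(" ".join(toks))
# [CLS] E V V M [SEP]
# [CLS] A L [SEP] [PAD] [PAD]
```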
Sequence embeddings are generated by passing the tokens through the model:
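```python
# Inference only, so gradient tracking is disabled
with torch.no_grad():
    output = model(
        input_ids=tokens["input_ids"].to(device),
        attention_mask=tokens["attention_mask"].to(device),
    )
# Shape: (batch_size, padded_length, hidden_size), one vector per token
residue_embeddings = output.last_hidden_state
```

To obtain a per-sequence representation, average the embeddings of the residue tokens, excluding the special tokens: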
```python
import torch

# Zero out special tokens ([CLS], [SEP], [PAD]) before summing over embeddings;
# the mask must be on the same device as the embeddings
special_tokens_mask = tokens["special_tokens_mask"].to(device)
residue_embeddings[special_tokens_mask == 1] = 0
sequence_embeddings_sum = residue_embeddings.sum(1)
# Average by dividing the sum by the number of residues in each sequence
sequence_lengths = torch.sum(special_tokens_mask == 0, dim=1)
sequence_embeddings = sequence_embeddings_sum / sequence_lengths.unsqueeze(1)
```

For sequence-level fine-tuning, load the model with a pooling head by setting `add_pooling_layer=True` and use `output.pooler_output` in the downstream task.
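The masked averaging step can be checked by hand on a toy batch, without loading the model. The sketch below is a pure-Python re-implementation of the same logic under stated assumptions: the 2-dimensional embedding values are made up, `masked_mean` is a hypothetical helper, and only the mask layout follows the [CLS] E V V M [SEP] / [CLS] A L [SEP] [PAD] [PAD] example. Special positions carry the value 9 so that any leak into the average would be obvious.

```python
# Made-up embeddings: 2 sequences x 6 positions x 2 dimensions.
# Positions marked 1 in the mask ([CLS], [SEP], [PAD]) hold the sentinel value 9.
toy_embeddings = [
    [[9, 9], [1, 2], [3, 4], [5, 6], [7, 8], [9, 9]],  # [CLS] E V V M [SEP]
    [[9, 9], [2, 0], [4, 2], [9, 9], [9, 9], [9, 9]],  # [CLS] A L [SEP] [PAD] [PAD]
]
toy_mask = [
    [1, 0, 0, 0, 0, 1],
    [1, 0, 0, 1, 1, 1],
]


def masked_mean(embeddings, special_mask):
    """Average only residue positions (mask == 0), one vector per sequence."""
    pooled = []
    for seq_emb, seq_mask in zip(embeddings, special_mask):
        kept = [vec for vec, m in zip(seq_emb, seq_mask) if m == 0]
        dim = len(seq_emb[0])
        pooled.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return pooled


print(masked_mean(toy_embeddings, toy_mask))
# [[4.0, 5.0], [3.0, 1.0]]
```

Dividing by the residue count rather than the padded length is what keeps the shorter sequence's average from being diluted by [PAD] positions.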