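This model is pretrained on protein and antibody sequences with a masked language modelling (MLM) objective. It was introduced in the paper *Large scale paired antibody language models* (https://arxiv.org/abs/2403.17889).
It is fine-tuned from IgBert-unpaired on paired antibody sequences from the Observed Antibody Space.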
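The MLM objective can be illustrated with a small pure-Python sketch. It assumes the standard BERT masking recipe: about 15% of positions are selected, and of those 80% are replaced by `[MASK]`, 10% by a random residue and 10% are left unchanged. The exact masking scheme used for IgBert is not stated here and may differ. `mask_residues` is a hypothetical helper, not part of any library.

```python
import random

MASK = "[MASK]"
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids


def mask_residues(residues, mask_prob=0.15, seed=0):
    """BERT-style masking sketch (assumed recipe, not necessarily IgBert's).

    Returns (inputs, labels): labels[i] is the original residue at positions
    the model must predict, and None at positions ignored by the loss.
    """
    rng = random.Random(seed)
    inputs = list(residues)
    labels = [None] * len(residues)
    for i, residue in enumerate(residues):
        if rng.random() >= mask_prob:
            continue  # position not selected as a prediction target
        labels[i] = residue
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.choice(AMINO_ACIDS)  # 10%: random residue
        # remaining 10%: keep the original residue
    return inputs, labels


heavy = "VQLAQSGSELRKPGASVKVSCDTSGHSFTSNAIHWVRQAPGQGLEWMG"
inputs, labels = mask_residues(heavy)
print(" ".join(inputs))
print(sum(label is not None for label in labels), "positions to predict")
```

Keeping some selected residues unchanged or randomised means the model cannot assume that every unmasked token is correct, which is the usual motivation for the 80/10/10 split.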
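The model and tokeniser can be loaded with the transformers library. In the example below, the openMind helpers select an Ascend NPU when one is available and download the weights from the hub when no local path is given.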
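```python
import argparse

import torch
from transformers import BertModel, BertTokenizer
from openmind import is_torch_npu_available
from openmind_hub import snapshot_download

# Use an Ascend NPU when one is available, otherwise fall back to the CPU
if is_torch_npu_available():
    import torch_npu  # noqa: F401  registers the "npu" device with PyTorch
    device = "npu:0"
else:
    device = "cpu"

parser = argparse.ArgumentParser()
# Leave unset to download the weights from the hub; pass a local directory to skip the download
parser.add_argument("--model_name_or_path", type=str, default=None)
args = parser.parse_args()

if args.model_name_or_path:
    model_path = args.model_name_or_path
else:
    model_path = snapshot_download(
        "CICC/IgBert",
        revision="main",
        resume_download=True,
        ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
    )

# Keep the upper-case residue letters as they are
tokeniser = BertTokenizer.from_pretrained(model_path, do_lower_case=False)
# float16 on the NPU; float32 on the CPU, where half-precision support is limited
dtype = torch.float16 if device.startswith("npu") else torch.float32
model = BertModel.from_pretrained(model_path, add_pooling_layer=False, torch_dtype=dtype)
```

The tokeniser is used to prepare batched inputs.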
```python
# Heavy-chain sequences
sequences_heavy = [
    "VQLAQSGSELRKPGASVKVSCDTSGHSFTSNAIHWVRQAPGQGLEWMGWINTDTGTPTYAQGFTGRFVFSLDTSARTAYLQISSLKADDTAVFYCARERDYSDYFFDYWGQGTLVTVSS",
    "QVQLVESGGGVVQPGRSLRLSCAASGFTFSNYAMYWVRQAPGKGLEWVAVISYDGSNKYYADSVKGRFTISRDNSKNTLYLQMNSLRTEDTAVYYCASGSDYGDYLLVYWGQGTLVTVSS",
]

# Light-chain sequences
sequences_light = [
    "EVVMTQSPASLSVSPGERATLSCRARASLGISTDLAWYQQRPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDSAVYYCQQYSNWPLTFGGGTKVEIK",
    "ALTQPASVSGSPGQSITISCTGTSSDVGGYNYVSWYQQHPGKAPKLMIYDVSKRPSGVSNRFSGSKSGNTASLTISGLQSEDEADYYCNSLTSISTWVFGGGTKLTVL",
]

# The tokeniser expects input of the form ["V Q ... S S [SEP] E V ... I K", ...]:
# space-separated residues, with the heavy and light chains joined by [SEP]
paired_sequences = []
for sequence_heavy, sequence_light in zip(sequences_heavy, sequences_light):
    paired_sequences.append(" ".join(sequence_heavy) + " [SEP] " + " ".join(sequence_light))

tokens = tokeniser.batch_encode_plus(
    paired_sequences,
    add_special_tokens=True,
    padding=True,  # pad to the longest sequence in the batch (replaces the deprecated pad_to_max_length)
    return_tensors="pt",
    return_special_tokens_mask=True,
)

model = model.to(device)
model.eval()  # disable dropout for inference
```

Note that the tokeniser adds a `[CLS]` token at the start of each paired sequence and a `[SEP]` token at the end, and pads with `[PAD]` tokens. For example, a batch containing the sequences `V Q L [SEP] E V V` and `Q V [SEP] A L` is tokenised to `[CLS] V Q L [SEP] E V V [SEP]` and `[CLS] Q V [SEP] A L [SEP] [PAD] [PAD]`.
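The layout in the example can be reproduced without downloading the tokeniser. The pure-Python sketch below assumes one token per residue letter, which holds here because the residues are space-separated. Both helpers are hypothetical: they only mimic the layout produced by `BertTokenizer` and do not call it.

```python
def pair_for_tokeniser(heavy, light):
    """Format one heavy/light pair the way the tokeniser expects:
    space-separated residues with the two chains joined by [SEP]."""
    return " ".join(heavy) + " [SEP] " + " ".join(light)


def expected_token_layout(paired_sequences):
    """Mimic the tokenised layout: [CLS] + tokens + [SEP], then [PAD]
    up to the length of the longest sequence in the batch."""
    rows = [["[CLS]"] + seq.split() + ["[SEP]"] for seq in paired_sequences]
    width = max(len(row) for row in rows)
    return [row + ["[PAD]"] * (width - len(row)) for row in rows]


batch = [pair_for_tokeniser("VQL", "EVV"), pair_for_tokeniser("QV", "AL")]
layout = expected_token_layout(batch)
for row in layout:
    print(" ".join(row))
# [CLS] V Q L [SEP] E V V [SEP]
# [CLS] Q V [SEP] A L [SEP] [PAD] [PAD]
```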
Sequence embeddings are generated by passing the tokens through the model.
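```python
# Inference only, so no gradients are needed
with torch.no_grad():
    output = model(
        input_ids=tokens["input_ids"].to(device),
        attention_mask=tokens["attention_mask"].to(device),
    )

# One embedding per token: shape (batch_size, sequence_length, hidden_size)
residue_embeddings = output.last_hidden_state
```

To obtain a per-sequence representation, average the embeddings of the residue tokens as follows.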
```python
# Zero out the positions flagged in special_tokens_mask (added special tokens and padding)
# before summing; the mask must live on the same device as the embeddings
special_tokens_mask = tokens["special_tokens_mask"].to(device)
residue_embeddings[special_tokens_mask == 1] = 0
sequence_embeddings_sum = residue_embeddings.sum(1)

# Average the embeddings by dividing the sum by each sequence's number of unflagged tokens
sequence_lengths = torch.sum(special_tokens_mask == 0, dim=1)
sequence_embeddings = sequence_embeddings_sum / sequence_lengths.unsqueeze(1)
```

For sequence-level fine-tuning, load the model with a pooling head by setting `add_pooling_layer=True`, and use `output.pooler_output` in the downstream task.
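The sketch below contrasts the two ways of getting one vector per sequence: the masked mean over residue embeddings and the pooler head. To run without downloading anything, it uses a tiny, randomly initialised `BertConfig` instead of the IgBert weights. The token ids, the special-token mask and the two-class linear classifier are hypothetical illustrations, not part of the IgBert release.

```python
import torch
from transformers import BertConfig, BertModel

torch.manual_seed(0)

# Tiny random config for illustration only; IgBert's real sizes come from its own config
config = BertConfig(vocab_size=30, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
model = BertModel(config, add_pooling_layer=True).eval()

# Hypothetical ids: 2 = [CLS], 3 = [SEP], 0 = [PAD] (BertConfig's default pad_token_id)
input_ids = torch.tensor([[2, 5, 6, 7, 3, 8, 9, 3],
                          [2, 5, 6, 3, 8, 3, 0, 0]])
attention_mask = (input_ids != 0).long()
# Hypothetical special-token mask: 1 at [CLS], the final [SEP] and [PAD]
special_tokens_mask = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 1],
                                    [1, 0, 0, 0, 0, 1, 1, 1]])

with torch.no_grad():
    output = model(input_ids=input_ids, attention_mask=attention_mask)

# Option 1: masked mean over residue embeddings (same arithmetic as the snippet above)
keep = (special_tokens_mask == 0).unsqueeze(-1).to(output.last_hidden_state.dtype)
mean_embeddings = (output.last_hidden_state * keep).sum(1) / keep.sum(1)

# Option 2: the pooler head, i.e. tanh(Linear(hidden state at the [CLS] position))
pooled = output.pooler_output

# A hypothetical downstream head for a 2-class sequence-level task
classifier = torch.nn.Linear(config.hidden_size, 2)
logits = classifier(pooled)
print(mean_embeddings.shape, pooled.shape, logits.shape)
```

The masked mean draws on every residue position, while `pooler_output` summarises the sequence through the `[CLS]` position alone, so the pooler is the natural input for a head that is trained together with the model.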