ESM-2 是一个以掩码语言建模(MLM)为训练目标的先进蛋白质语言模型,可针对以蛋白质序列为输入的各类下游任务进行微调。

**硬件环境**

| 硬件名称 | 配置信息 |
|---|---|
| 机器型号 | Atlas 800T A3 |
| CPU型号 | HUAWEI Kunpeng 920 |
| CPU数量 | 4 |
| AI加速芯片型号 | 昇腾910A3 |
| AI加速芯片数量 | 8 |
| Device内存 | 64GB |
| 主机内存类型 | DDR4 |
| 内存条数量 | 24 |
| 单条容量 | 64GB |
| 内存插槽数 | 32 |
| 硬盘数量 | 2 |
| 内部存储类型 | SSD |
| 单硬盘容量 | 3.2TB |

**软件环境**

| 软件分类 | 软件对象 | 说明 | 备注 |
|---|---|---|---|
| Ascend HDK | Ascend-A3-hdk-npu-driver*.run、Ascend-A3-hdk-npu-firmware*.run | 昇腾驱动、固件 | |
| CANN | Ascend-cann-toolkit*.run、Ascend-cann-kernels-A3*.run | 华为针对AI场景推出的异构计算架构 | |
| PyTorch | torch | PyTorch框架 | |
| torch_npu | torch_npu-*.whl | PyTorch Ascend Adapter插件,使昇腾NPU适配PyTorch框架 | |

export IMAGE=quay.io/ascend/vllm-ascend:v0.13.0rc1
# Create and start the container. This opens an interactive bash shell;
# --rm deletes the container (and anything written outside the mounted
# volumes) when that shell exits.
docker run --rm \
    --name vllm-ascend \
    --shm-size=1g \
    --privileged \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
    -v /etc/ascend_install.info:/etc/ascend_install.info \
    -v /root/.cache:/root/.cache \
    -p 8000:8000 \
    -it "$IMAGE" bash
# Inside the container shell: install the transformers version this demo targets.
pip install transformers==5.0.0
拉取模型权重:
# Download the ESM-2 650M weights into ./model/esm2_t33_650M_UR50D (relative to
# the current directory). This is one of the local paths the demo script checks
# before falling back to the "facebook/esm2_t33_650M_UR50D" hub id, so run the
# demo from this same directory.
modelscope download --model facebook/esm2_t33_650M_UR50D --local_dir ./model/esm2_t33_650M_UR50D
新建 demo.py 文件,并添加以下内容:
from pathlib import Path
import sys
import time
# Make the directory one level above this script importable, placing it ahead
# of every other sys.path entry; skip the insert if it is already present.
_root = Path(__file__).resolve().parents[1]
_root_entry = str(_root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from transformers import AutoTokenizer, EsmForMaskedLM
def _least_used_npu(npu_count: int) -> "tuple[int, int] | None":
    """Return ``(npu_id, used_mb)`` for the visible NPU with the least used memory.

    Runs ``npu-smi info`` and parses the ``used / total`` memory column of each
    device row. Returns ``None`` (after printing the reason) when the tool is
    missing, fails, times out, or its output does not match the expected layout.

    Args:
        npu_count: Number of NPUs visible to torch; rows beyond this are ignored.
    """
    import re
    import subprocess

    try:
        # Bounded wait: a wedged driver must not hang device selection forever.
        output = subprocess.check_output(["npu-smi", "info"], encoding="utf-8", timeout=10)
    except (OSError, ValueError, subprocess.SubprocessError) as exc:
        print(f"Could not query npu-smi ({exc}); falling back to npu:0.")
        return None

    matches = re.findall(
        r"\|\s+\d+\s+\d+\s+.*?\|\s+.*?\|\s+.*?\s+.*?/\s+.*?\s+(\d+)\s*/\s*(\d+)\s+\|",
        output,
    )
    # NOTE(review): assumes the i-th matched row corresponds to torch device i.
    # That breaks if ASCEND_RT_VISIBLE_DEVICES remaps devices or if npu-smi
    # prints more than one row per torch device - confirm on the target host.
    usages = [(idx, int(used)) for idx, (used, _total) in enumerate(matches) if idx < npu_count]
    if not usages:
        print("Could not parse memory usage from npu-smi output; falling back to npu:0.")
        return None
    return min(usages, key=lambda item: item[1])


def get_device() -> str:
    """Pick the compute device to run inference on.

    Preference order is Ascend NPU, then NVIDIA CUDA, then CPU. When more than
    one NPU is visible, the one reporting the least used memory in
    ``npu-smi info`` is chosen; if that probe fails, ``"npu:0"`` is used.

    Returns:
        A device string accepted by ``torch.device``: ``"npu:<id>"``,
        ``"cuda:0"`` or ``"cpu"``.
    """
    # torch.npu only exists once torch_npu has been imported.
    npu = getattr(torch, "npu", None)
    if npu is not None and npu.is_available():
        print("Ascend NPU detected.")
        npu_count = npu.device_count()
        if npu_count > 1:
            best = _least_used_npu(npu_count)
            if best is not None:
                best_id, min_usage = best
                print(f"Selected NPU:{best_id} with least memory usage ({min_usage} MB).")
                return f"npu:{best_id}"
        return "npu:0"
    if torch.cuda.is_available():
        print("NVIDIA CUDA detected.")
        return "cuda:0"
    print("No hardware acceleration detected, using CPU.")
    return "cpu"
def _synchronize(dev: torch.device) -> None:
    """Block until all work queued on *dev* has finished; no-op for CPU.

    NPU/CUDA kernels are launched asynchronously, so wall-clock timing taken
    without this call measures only the launch, not the computation.
    """
    if dev.type == "npu":
        torch.npu.synchronize(dev)
    elif dev.type == "cuda":
        torch.cuda.synchronize(dev)


def main() -> None:
    """Run one ESM-2 650M fill-mask prediction and print the top-5 tokens.

    The model is loaded from the first existing local directory in a fixed list,
    otherwise from the "facebook/esm2_t33_650M_UR50D" hub id. One position of a
    short example peptide is masked, and the five highest-scoring tokens for that
    position are printed together with the forward-pass latency. If the model
    cannot be loaded, a hint is printed and the function returns early.
    """
    device = get_device()
    print(f"Using device: {device}")
    dev = torch.device(device)

    # Prefer a locally downloaded copy (e.g. via `modelscope download`) over the hub.
    model_id = "facebook/esm2_t33_650M_UR50D"
    for local_path in [
        "./models/esm2_t33_650M_UR50D",
        "./model/esm2_t33_650M_UR50D",
        "esm2-650m/models/esm2_t33_650M_UR50D",
        "esm2-650m/model/esm2_t33_650M_UR50D",
    ]:
        if Path(local_path).exists():
            model_id = local_path
            print(f"Using local model from: {model_id}")
            break

    print(f"Loading tokenizer and ESM-2 650M model: {model_id}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if not hasattr(EsmForMaskedLM, "from_pretrained"):
            raise ImportError("EsmForMaskedLM.from_pretrained not found.")
        model = EsmForMaskedLM.from_pretrained(
            model_id,
            # Half precision on accelerators; full precision on CPU.
            torch_dtype=torch.float16 if device != "cpu" else torch.float32,
            low_cpu_mem_usage=True,
        )
        model = model.to(dev).eval()  # type: ignore[call-arg]
    except Exception as e:
        # Top-level boundary of a demo script: report the failure and stop.
        print(f"\nError loading model: {e}")
        print("\n提示: 如果无法连接 Hugging Face,请尝试设置镜像源:")
        print("export HF_ENDPOINT=https://hf-mirror.com")
        return

    # Example protein sequence with one residue replaced by the tokenizer's
    # mask token (for ESM tokenizers this is "<mask>").
    mask_tok = tokenizer.mask_token
    sequence = f"MKTVRQERLKSIVR{mask_tok}LE"  # short peptide, one masked position
    inputs = tokenizer(sequence, return_tensors="pt").to(dev)
    print(f"\nInput protein sequence (with mask): {sequence}")
    print("Running ESM-2 fill-mask inference...")
    with torch.no_grad():
        # Warm-up pass: the first call on an accelerator typically includes
        # one-off initialization that would otherwise dominate the timing.
        model(**inputs)
        _synchronize(dev)
        start_time = time.perf_counter()
        logits = model(**inputs).logits
        # Wait for the asynchronous kernels so the interval covers the real work.
        _synchronize(dev)
        end_time = time.perf_counter()

    # Take the logits at the (first) masked position and decode the top-5 tokens.
    mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
    if mask_token_index[0].numel() == 0:
        print("No mask token found in input.")
        return
    batch_idx, seq_idx = mask_token_index[0][0].item(), mask_token_index[1][0].item()
    mask_logits = logits[batch_idx, seq_idx, :]
    top_k = 5
    top_ids = torch.topk(mask_logits, top_k).indices.tolist()
    print("\n--- Result (top predicted amino acids at mask position) ---")
    for i, tid in enumerate(top_ids, 1):
        token = tokenizer.decode([tid])
        print(f" {i}. {token}")
    print(f"Inference took: {end_time - start_time:.2f} seconds")
    print("--------------")
if __name__ == "__main__":
    main()

# Usage (from a shell, in the directory containing the model weights):
#   python demo.py