Ascend-SACT/ProtBert

ProtBERT Model Migration and Adaptation Guide

1. Model Overview

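ProtBERT is a deep-learning model for protein sequences built on the BERT architecture. It can be used for protein-sequence feature extraction, masked language model (MLM) pre-training, and classification tasks, and has demonstrated strong capabilities in fields such as drug discovery, materials science, and computational chemistry. The ProtBERT model described here is the implementation in the DeepChem toolkit, and the adaptation work that follows is likewise a modification of that toolkit.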
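ProtBERT consumes a protein as a string of one-letter amino-acid codes separated by spaces, as in the example sequence of the test script in section 3.3. The sketch below converts a raw sequence into that form. It assumes the preprocessing shown on the Rostlab/prot_bert model card, where the rare residues U, Z, O and B are replaced by X; the helper name `to_protbert_input` is hypothetical and not part of DeepChem or deepchem-ascend.

```python
import re

# Rare/ambiguous residues mapped to the unknown residue X
# (assumption: follows the Rostlab/prot_bert model card convention).
RARE_RESIDUES = re.compile(r"[UZOB]")


def to_protbert_input(sequence: str) -> str:
    """Convert a raw one-letter amino-acid string into ProtBERT's
    space-separated input format (hypothetical helper)."""
    # Remove any existing whitespace and normalise to upper case.
    cleaned = "".join(sequence.split()).upper()
    # Replace rare amino acids with X.
    cleaned = RARE_RESIDUES.sub("X", cleaned)
    # One residue per token: separate residues with single spaces.
    return " ".join(cleaned)


print(to_protbert_input("mglpvswappalwvlgccalllslwa"))
# prints: M G L P V S W A P P A L W V L G C C A L L L S L W A
```

The output matches the `protein` string used in the test script, so already-formatted input passes through unchanged apart from the rare-residue mapping.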

2. Prepare the Runtime Environment

2.1 Software Environment

| Component | Version |
|---|---|
| Python | 3.10.19 |
| PyTorch | 2.1.0 |
| torch_npu | 2.1.0.post13 |
| CANN | 8.1.RC1 |
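To confirm that a container matches the table above, a small script along the following lines can compare installed versions. This is a sketch, not a tool shipped with deepchem-ascend: it assumes the pip distribution names `torch` and `torch-npu`, uses an exact string comparison (so a local suffix such as `+cpu` would count as a mismatch), and does not check CANN, which is not installed through pip.

```python
import sys
from importlib import metadata

# Expected versions, copied from the software-environment table.
EXPECTED = {
    "torch": "2.1.0",
    "torch-npu": "2.1.0.post13",  # assumption: pip distribution name
}
EXPECTED_PYTHON = "3.10.19"


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_environment(expected=EXPECTED, get_version=installed_version):
    """Return (name, expected, found) tuples for every mismatch."""
    mismatches = []
    python = ".".join(str(part) for part in sys.version_info[:3])
    if python != EXPECTED_PYTHON:
        mismatches.append(("python", EXPECTED_PYTHON, python))
    for name, want in expected.items():
        found = get_version(name)
        if found != want:  # strict string comparison
            mismatches.append((name, want, found))
    return mismatches


if __name__ == "__main__":
    for name, want, found in check_environment():
        print(f"{name}: expected {want}, found {found}")
```

An empty report means Python, PyTorch and torch_npu match the table; the CANN version still has to be checked against the toolkit installed in the image.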

2.2 Hardware Environment

| Device model | NPU configuration |
|---|---|
| Atlas 800T A2 | Single card / multi-card |

2.3 Prepare the Image

| Environment | Image address |
|---|---|
| Public network | swr.cn-southwest-2.myhuaweicloud.com/atelier/pytorch_2_1_ascend:pytorch_2.1.0-cann_8.1.rc1-py_3.10-euler_2.10.11-aarch64-snt9b-20250603154214-4e60e43 |

2.4 Start the Container

docker run -u root --privileged \
 --name {container_name} \
 --device /dev/davinci0 \
 --device /dev/davinci_manager \
 --device /dev/devmm_svm \
 --device /dev/hisi_hdc \
 -v /usr/local/dcmi:/usr/local/dcmi \
 -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
 -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
 -v /etc/ascend_install.info:/etc/ascend_install.info \
 -itd {image_id} /bin/bash
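The command above exposes only `/dev/davinci0`, i.e. a single card. For the multi-card configuration listed in 2.2, one `--device /dev/davinciN` flag is typically added per card. The hypothetical helper below builds the same argument list for an arbitrary set of card ids, assuming card N is exposed as `/dev/davinciN` on the host; `protbert` and `IMAGE_ID` in the usage line are placeholders.

```python
import shlex

# Host paths mounted into the container, as in the command above.
MOUNTS = [
    "/usr/local/dcmi:/usr/local/dcmi",
    "/usr/local/bin/npu-smi:/usr/local/bin/npu-smi",
    "/usr/local/Ascend/driver:/usr/local/Ascend/driver",
    "/etc/ascend_install.info:/etc/ascend_install.info",
]
# Management devices passed through regardless of how many cards are used.
SHARED_DEVICES = ["/dev/davinci_manager", "/dev/devmm_svm", "/dev/hisi_hdc"]


def docker_run_args(container_name, image_id, card_ids):
    """Build the docker run argv for the given NPU card ids (hypothetical helper)."""
    args = ["docker", "run", "-u", "root", "--privileged", "--name", container_name]
    # Assumption: card N appears on the host as /dev/davinciN.
    for card in card_ids:
        args += ["--device", f"/dev/davinci{card}"]
    for dev in SHARED_DEVICES:
        args += ["--device", dev]
    for mount in MOUNTS:
        args += ["-v", mount]
    args += ["-itd", image_id, "/bin/bash"]
    return args


# Print a shell-ready command exposing all eight cards.
print(shlex.join(docker_run_args("protbert", "IMAGE_ID", range(8))))
```

Generating the argument list in one place keeps the per-card devices and the shared management devices from drifting apart when the card count changes.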

3. Running Guide

3.1 Install the Migrated deepchem-ascend Package

Enter the container:

docker exec -it {container_name} bash

Install the deepchem-ascend binary package directly. It has been rebuilt from the adapted code and published to PyPI.

pip install deepchem-ascend==0.0.1

3.2 Install Dependencies

Install the libraries required at runtime:

pip install torch-geometric
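A quick way to confirm that the packages from 3.1 and 3.2 landed in the active environment is to look up their import names. The sketch assumes that deepchem-ascend installs the `deepchem` module (which is what the test script imports) and that torch-geometric installs `torch_geometric`; it only checks that the modules can be found, not that they import cleanly.

```python
import importlib.util

# Top-level import names expected after 3.1 and 3.2 (assumed mapping from
# the pip package names deepchem-ascend and torch-geometric).
REQUIRED_MODULES = ["deepchem", "torch_geometric", "torch_npu"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be found on sys.path."""
    # find_spec returns None for an absent top-level module instead of raising.
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
print("All modules found" if not missing else f"Missing: {', '.join(missing)}")
```

Only top-level names are passed in, because `find_spec` on a dotted name imports the parent package and may raise if the parent is missing.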

3.3 Test and Verification

An example test script:

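import sys

import torch
import torch_npu  # noqa: F401  (registers the "npu" device type with PyTorch)
from deepchem.models.torch_models.prot_bert import ProtBERT
from deepchem.models.torch_models.torch_model import is_npu_available

if not is_npu_available():
    # torch.device('npu') below would fail without an NPU, so actually stop here.
    print("NPU not available, skipping test")
    sys.exit(0)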

protein = "M G L P V S W A P P A L W V L G C C A L L L S L W A"
model_path = 'Rostlab/prot_bert'

device = torch.device('npu')
protbert_feature_extractor = ProtBERT(
    task='feature_extractor',
    model_path=model_path,
    n_tasks=1,
    batch_size=1
)

protbert_feature_extractor.device = device
protbert_feature_extractor.model = protbert_feature_extractor.model.to(device)

model_device = next(protbert_feature_extractor.model.parameters()).device
print(f"Model device: {model_device}")

tokenized_data = protbert_feature_extractor.tokenizer(protein, return_tensors='pt')

input_ids = tokenized_data['input_ids'].to(device)
attention_mask = tokenized_data['attention_mask'].to(device)

protbert_feats = protbert_feature_extractor.get_last_hidden_state(input_ids, attention_mask)
print(f"Feature extraction successful on NPU, output shape: {protbert_feats.shape}")

Test result: [screenshot of the test output]
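`get_last_hidden_state` returns per-token features. If a single fixed-size vector per protein is needed, for example as input to a downstream classifier, one common option is masked mean pooling over the tokens. The sketch below assumes the returned tensor has shape (batch, seq_len, hidden), which the shape printed by the test script can confirm; `mean_pool` is a hypothetical helper, and the toy tensors are random placeholders rather than real ProtBERT output.

```python
import torch


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over non-padding positions.

    last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len).
    Returns a (batch, hidden) tensor with one embedding per sequence.
    """
    # Add a trailing axis so the mask broadcasts over the hidden dimension.
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Clamp the token count to avoid division by zero for an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


# Toy check on CPU with random values standing in for ProtBERT output.
hidden = torch.randn(2, 5, 4)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(mean_pool(hidden, mask).shape)  # torch.Size([2, 4])
```

With the test script's tensors this would be called as `mean_pool(protbert_feats, attention_mask)`. Because the mask also covers the [CLS] and [SEP] tokens added by the tokenizer, they are included in the average; slicing them off first is an equally reasonable design choice.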