Ascend-SACT/InfoGraph
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

InfoGraph 模型迁移适配指导

1. 模型概述

InfoGraph 是一个用于无监督图级别表示学习的模型,通过互信息最大化(InfoMax)来学习分子图的嵌入表示。模型包含两个编码器:一个编码整个图,另一个编码不同粒度的子结构(如节点、边),通过对比损失函数最大化两者之间的互信息。编码器采用 GIN(Graph Isomorphism Network)和 NNConv 等图卷积层,可捕获分子拓扑结构和特征信息。InfoGraph 主要用于分子表示的预训练,也可直接用于图分类等下游任务。本文描述的 InfoGraph 模型是基于 DeepChem 套件实现的,后续适配也是基于该套件修改。

2. 准备运行环境

2.1 软件环境

组件版本
Python3.10.19
PyTorch2.1.0
torch_npu2.1.0.post13
CANN8.1.RC1

2.2 硬件环境

设备型号NPU 配置
Atlas 800T A2单卡 / 多卡

2.3 准备镜像

镜像环境镜像地址
公网swr.cn-southwest-2.myhuaweicloud.com/atelier/pytorch_2_1_ascend:pytorch_2.1.0-cann_8.1.rc1-py_3.10-euler_2.10.11-aarch64-snt9b-20250603154214-4e60e43

2.4 启动镜像

docker run -u root --privileged \
 --name {container_name} \
 --device /dev/davinci0 \
 --device /dev/davinci_manager \
 --device /dev/devmm_svm \
 --device /dev/hisi_hdc \
 -v /usr/local/dcmi:/usr/local/dcmi \
 -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
 -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
 -v /etc/ascend_install.info:/etc/ascend_install.info \
 -itd {image_id} /bin/bash

3 运行指导

3.1 安装已完成迁移的deepchem-ascend

docker exec -it {container_name} bash

直接安装 deepchem-ascend 二进制包,该包已基于适配代码重新编译并上传至 PyPI。

pip install deepchem-ascend==0.0.1

3.2 安装依赖

安装运行必要的依赖库

pip install torch-geometric

3.3 测试验证

测试代码举例如下:

import numpy as np
import torch
from deepchem.feat.molecule_featurizers import MolGraphConvFeaturizer
from deepchem.feat.graph_data import BatchGraphData
from deepchem.models.torch_models.torch_model import is_npu_available

try:
    from deepchem.models.torch_models.infograph import InfoGraphModel, InfoGraphStarModel
except ImportError:
    print("Failed to import InfoGraph, please ensure torch-geometric is installed")
    exit(1)

if not is_npu_available():
    print("NPU not available, skipping test")
    exit()

device = torch.device('npu')
print(f"Testing InfoGraph on NPU device: {device}")

num_feat = 30
edge_dim = 11
embedding_dim = 64
num_gc_layers = 2

smiles = ['C1=CC=CC=C1', 'C1=CC=CC=C1C2=CC=CC=C2']
featurizer = MolGraphConvFeaturizer(use_edges=True)
graphs = featurizer.featurize(smiles)
batch_graphs = BatchGraphData(graphs)

print(f"Number of graphs: {len(graphs)}")

print("Testing InfoGraphModel (pretraining task)...")
infograph_model = InfoGraphModel(
    num_features=num_feat,
    embedding_dim=embedding_dim,
    num_gc_layers=num_gc_layers,
    task='pretraining',
    device=device
)

print(f"InfoGraphModel device: {infograph_model.device}")
inputs = batch_graphs.numpy_to_torch(device)
g_enc, l_enc = infograph_model.model(inputs)
print(f"InfoGraphModel forward pass successful, g_enc shape: {g_enc.shape}, l_enc shape: {l_enc.shape}")

print("\nTesting InfoGraphModel (regression task)...")
regression_model = InfoGraphModel(
    num_features=num_feat,
    embedding_dim=embedding_dim,
    num_gc_layers=num_gc_layers,
    task='regression',
    n_tasks=1,
    device=device
)
print(f"Regression model device: {regression_model.device}")
inputs = batch_graphs.numpy_to_torch(device)
regression_output = regression_model.model(inputs)
print(f"Regression model forward pass successful, output shape: {regression_output.shape}")

print("\nTesting InfoGraphModel (classification task)...")
classification_model = InfoGraphModel(
    num_features=num_feat,
    embedding_dim=embedding_dim,
    num_gc_layers=num_gc_layers,
    task='classification',
    n_tasks=1,
    n_classes=2,
    device=device
)
print(f"Classification model device: {classification_model.device}")
inputs = batch_graphs.numpy_to_torch(device)
classification_output = classification_model.model(inputs)
print(f"Classification model forward pass successful, output shape: {classification_output.shape}")

测试结果: image