[ PROMPT_NODE_27476 ]
torch-geometric
[ SKILL_DOCUMENTATION ]
# PyTorch Geometric (PyG)
## 概述
PyTorch Geometric 是一个基于 PyTorch 的库,用于开发和训练图神经网络 (GNN)。应用此技能进行图数据和不规则结构的深度学习,包括小批量处理、多 GPU 训练和几何深度学习应用。
## 何时使用此技能
当处理以下任务时应使用此技能:
- **基于图的机器学习**:节点分类、图分类、链接预测
- **分子属性预测**:药物发现、化学属性预测
- **社交网络分析**:社区检测、影响力预测
- **引文网络**:论文分类、推荐系统
- **3D 几何数据**:点云、网格、分子结构
- **异构图**:多类型节点和边(例如知识图谱)
- **大规模图学习**:邻居采样、分布式训练
## 快速入门
### 安装
bash
uv pip install torch_geometric
获取额外依赖(稀疏操作、聚类):
bash
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
### 基本图创建
python
import torch
from torch_geometric.data import Data
# 创建一个包含 3 个节点的简单图
edge_index = torch.tensor([[0, 1, 1, 2], # 源节点
[1, 0, 2, 1]], dtype=torch.long) # 目标节点
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # 节点特征
data = Data(x=x, edge_index=edge_index)
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")
### 加载基准数据集
python
from torch_geometric.datasets import Planetoid
# 加载 Cora 引文网络
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # 获取第一个(也是唯一一个)图
print(f"数据集: {dataset}")
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")
print(f"特征数: {data.num_node_features}, 类别数: {dataset.num_classes}")
## 核心概念
### 数据结构
PyG 使用 `torch_geometric.data.Data` 类表示图,具有以下关键属性:
- **`data.x`**: 节点特征矩阵 `[节点数, 节点特征数]`
- **`data.edge_index`**: COO 格式的图连接关系 `[2, 边数]`
- **`data.edge_attr`**: 边特征矩阵 `[边数, 边特征数]` (可选)
- **`data.y`**: 节点或图的目标标签
- **`data.pos`**: 节点空间位置