一、图论与图表示基础
1. 图的基本概念
# 图的数学定义 G = (V, E) # V: 节点集合 (vertices/nodes) # E: 边集合 (edges)2. 图的表示方式
# 方式1: 邻接矩阵 (Adjacency Matrix) # 适合稠密图,但晶体通常是稀疏的 adj_matrix = torch.tensor([ [0, 1, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1], [0, 1, 1, 0] ]) # 方式2: 边列表 (Edge List) - GNN常用 # 更节省内存,适合稀疏图 edge_index = torch.tensor([ [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], # 源节点 [1, 2, 0, 2, 3, 0, 1, 3, 1, 2] # 目标节点 ]) # 方式3: 带属性的图 node_features = torch.tensor([...]) # 节点特征 (原子类型、电荷等) edge_features = torch.tensor([...]) # 边特征 (键长、键类型等)3. 晶体的图表示
## 晶体特殊性: 周期性边界条件 class CrystalGraph: def __init__(self, atoms, lattice, cutoff=5.0): """ atoms: 原子列表 [(元素, 坐标), ...] lattice: 晶格矩阵 3x3 cutoff: 截断半径,超过此距离不建边 """ self.atoms = atoms self.lattice = lattice self.cutoff = cutoff # 构建图 self.node_features = self._get_atom_features() self.edge_index, self.edge_features = self._build_edges() def _build_edges(self): """ 考虑周期性边界条件建边 需要考虑相邻晶胞中的原子 """ edges = [] edge_attrs = [] # 遍历所有原子对 for i, (elem_i, pos_i) in enumerate(self.atoms): for j, (elem_j, pos_j) in enumerate(self.atoms): # 考虑周期性镜像 for image in self._get_periodic_images(): pos_j_image = pos_j + image @ self.lattice distance = np.linalg.norm(pos_i - pos_j_image) if 0 < distance < self.cutoff: edges.append([i, j]) edge_attrs.append([distance]) # 边特征:距离 return torch.tensor(edges).T, torch.tensor(edge_attrs)二、图神经网络基础
1. 消息传递范式 (Message Passing)
这是所有GNN的统一框架:
# 核心思想: 节点通过聚合邻居信息来更新自己 def message_passing_layer(node_features, edge_index, edge_features): """ 1. Message: 每条边生成一个消息 2. Aggregate: 每个节点聚合收到的所有消息 3. Update: 用聚合后的消息更新节点特征 """ src, dst = edge_index # 源节点和目标节点 # 1. 生成消息 messages = message_function( node_features[src], # 源节点特征 node_features[dst], # 目标节点特征 edge_features # 边特征 ) # 2. 聚合消息 (sum/mean/max) aggregated = scatter_add(messages, dst, dim=0) # 3. 更新节点 new_features = update_function(node_features, aggregated) return new_features2. 图卷积网络 (GCN)
import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import GCNConv class GCN(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim): super().__init__() self.conv1 = GCNConv(in_dim, hidden_dim) self.conv2 = GCNConv(hidden_dim, hidden_dim) self.conv3 = GCNConv(hidden_dim, out_dim) def forward(self, x, edge_index): # x: 节点特征 [num_nodes, in_dim] # edge_index: 边 [2, num_edges] x = F.relu(self.conv1(x, edge_index)) x = F.relu(self.conv2(x, edge_index)) x = self.conv3(x, edge_index) return x # GCN的数学形式 # H^(l+1) = σ(D^(-1/2) A D^(-1/2) H^(l) W^(l)) # A: 邻接矩阵 + 自环 # D: 度矩阵 # H: 节点特征矩阵 # W: 可学习权重3. 图注意力网络 (GAT)
from torch_geometric.nn import GATConv class GAT(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, heads=4): super().__init__() self.conv1 = GATConv(in_dim, hidden_dim, heads=heads) self.conv2 = GATConv(hidden_dim * heads, out_dim, heads=1) def forward(self, x, edge_index): x = F.elu(self.conv1(x, edge_index)) x = self.conv2(x, edge_index) return x # GAT的核心: 学习邻居的重要性权重 # α_ij = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j])) # h'_i = σ(Σ_j α_ij W h_j)4. 边特征的处理
晶体GNN需要处理边特征(如原子间距离):
from torch_geometric.nn import NNConv class EdgeConditionedConv(nn.Module): """ 边特征条件化的图卷积 消息函数依赖于边特征 """ def __init__(self, in_dim, out_dim, edge_dim): super().__init__() # 边特征 → 权重矩阵 self.edge_nn = nn.Sequential( nn.Linear(edge_dim, in_dim * out_dim), nn.ReLU() ) self.conv = NNConv(in_dim, out_dim, self.edge_nn) def forward(self, x, edge_index, edge_attr): return self.conv(x, edge_index, edge_attr)三、几何深度学习:不变性与等变性
这是晶体GNN最重要的理论基础!
1. 为什么需要几何约束?
物理世界的对称性: 1. 平移不变性: 整体移动晶体,性质不变 2. 旋转不变性: 旋转晶体,性质不变 3. 排列不变性: 原子编号顺序不影响性质 好的晶体模型必须尊重这些对称性!2. 如何实现不变性?
# 方法1: 使用不变特征 # 距离是旋转不变的! def invariant_edge_features(pos_i, pos_j): distance = torch.norm(pos_j - pos_i) # 标量,旋转不变 return distance # 方法2: 使用相对坐标 + 聚合 # 角度也是不变的 def get_angle(pos_i, pos_j, pos_k): vec_ij = pos_j - pos_i vec_ik = pos_k - pos_i cos_angle = (vec_ij @ vec_ik) / (norm(vec_ij) * norm(vec_ik)) return cos_angle3. 等变神经网络
# 等变网络保证: 旋转输入 → 输出也相应旋转 class EquivariantLayer(nn.Module): """ 处理标量和向量的等变层 """ def __init__(self, scalar_dim, vector_dim): super().__init__() self.scalar_net = nn.Linear(scalar_dim, scalar_dim) self.vector_net = nn.Linear(vector_dim, vector_dim) self.mix = nn.Linear(scalar_dim + vector_dim, scalar_dim) def forward(self, scalars, vectors): # scalars: [N, scalar_dim] - 不变量 # vectors: [N, 3, vector_dim] - 等变量 # 标量更新(可以用向量的模) vector_norms = vectors.norm(dim=1) # [N, vector_dim] scalar_update = self.mix(torch.cat([scalars, vector_norms], dim=-1)) new_scalars = scalars + self.scalar_net(scalar_update) # 向量更新(保持方向,只改变幅度) # 不能用任意变换,否则破坏等变性 vector_scale = self.vector_net(vector_norms).unsqueeze(1) # [N, 1, vector_dim] new_vectors = vectors * vector_scale return new_scalars, new_vectors四、球谐函数与角度信息
高级晶体GNN(如DimeNet、GemNet)使用球谐函数编码方向信息。
1. 为什么需要球谐函数?
问题: 如何表示3D方向,同时保持旋转等变性? 普通做法: 用(x, y, z)坐标 问题: 坐标依赖参考系,旋转后会变 球谐函数: 在球面上的"傅里叶基" 优点: 天然具有旋转等变性2.球谐函数基础
import e3nn from e3nn import o3 # 球谐函数 Y_l^m(θ, φ) # l: 角动量量子数 (0, 1, 2, ...) # m: 磁量子数 (-l, ..., 0, ..., l) # l=0: 1个基函数 (标量,不变) # l=1: 3个基函数 (向量,像x,y,z) # l=2: 5个基函数 (二阶张量) def spherical_harmonics(directions, lmax=2): """ 将方向向量编码为球谐函数 """ # directions: [N, 3] 单位向量 # 输出: [N, (lmax+1)^2] 球谐系数 return o3.spherical_harmonics( l=list(range(lmax + 1)), x=directions, normalize=True ) # 示例 direction = torch.tensor([[1.0, 0.0, 0.0]]) # x方向 sh = spherical_harmonics(direction, lmax=2) # 输出: [1, 9] (1 + 3 + 5 = 9个系数)3. 在GNN中使用球谐函数
class SphericalMessage(nn.Module): """ 使用球谐函数的消息传递 """ def __init__(self, hidden_dim, lmax=2): super().__init__() self.lmax = lmax self.sh_dim = (lmax + 1) ** 2 # 距离编码 self.distance_embedding = nn.Sequential( GaussianBasis(num_gaussians=50), nn.Linear(50, hidden_dim) ) # 球谐系数的权重 self.sh_weight = nn.Linear(hidden_dim, self.sh_dim * hidden_dim) def forward(self, x, pos, edge_index): src, dst = edge_index # 计算边向量 edge_vec = pos[dst] - pos[src] edge_dist = edge_vec.norm(dim=-1, keepdim=True) edge_dir = edge_vec / (edge_dist + 1e-8) # 球谐编码方向 sh = spherical_harmonics(edge_dir, self.lmax) # [E, sh_dim] # 距离编码 dist_emb = self.distance_embedding(edge_dist) # [E, hidden] # 生成消息 weights = self.sh_weight(dist_emb).view(-1, self.sh_dim, hidden_dim) messages = torch.einsum('es,esh->eh', sh, weights) # [E, hidden] # 聚合 return scatter_add(messages, dst, dim=0)五、模型对比与选择
| 模型 | 几何信息 | 计算效率 | 精度 | 适用场景 |
|---|---|---|---|---|
| CGCNN | 距离 | 快 | 中 | 快速筛选 |
| SchNet | 距离 | 快 | 中 | 分子性质 |
| MEGNet | 距离 | 快 | 中 | 晶体性质 |
| DimeNet | 距离+角度 | 中 | 高 | 精确预测 |
| ALIGNN | 距离+角度 | 中 | 高 | 晶体性质 |
| GemNet | 距离+角度+二面角 | 慢 | 很高 | 高精度 |
| M3GNet | 三体 | 中 | 高 | MD模拟 |
| CHGNet | 三体+电荷 | 中 | 很高 | MD模拟+电荷感知 |