news 2026/4/18 8:35:50

晶体图神经网络

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
晶体图神经网络

一、图论与图表示基础

1. 图的基本概念

# 图的数学定义 G = (V, E) # V: 节点集合 (vertices/nodes) # E: 边集合 (edges)

2. 图的表示方式

# 方式1: 邻接矩阵 (Adjacency Matrix) # 适合稠密图,但晶体通常是稀疏的 adj_matrix = torch.tensor([ [0, 1, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1], [0, 1, 1, 0] ]) # 方式2: 边列表 (Edge List) - GNN常用 # 更节省内存,适合稀疏图 edge_index = torch.tensor([ [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], # 源节点 [1, 2, 0, 2, 3, 0, 1, 3, 1, 2] # 目标节点 ]) # 方式3: 带属性的图 node_features = torch.tensor([...]) # 节点特征 (原子类型、电荷等) edge_features = torch.tensor([...]) # 边特征 (键长、键类型等)

3. 晶体的图表示

## 晶体特殊性: 周期性边界条件 class CrystalGraph: def __init__(self, atoms, lattice, cutoff=5.0): """ atoms: 原子列表 [(元素, 坐标), ...] lattice: 晶格矩阵 3x3 cutoff: 截断半径,超过此距离不建边 """ self.atoms = atoms self.lattice = lattice self.cutoff = cutoff # 构建图 self.node_features = self._get_atom_features() self.edge_index, self.edge_features = self._build_edges() def _build_edges(self): """ 考虑周期性边界条件建边 需要考虑相邻晶胞中的原子 """ edges = [] edge_attrs = [] # 遍历所有原子对 for i, (elem_i, pos_i) in enumerate(self.atoms): for j, (elem_j, pos_j) in enumerate(self.atoms): # 考虑周期性镜像 for image in self._get_periodic_images(): pos_j_image = pos_j + image @ self.lattice distance = np.linalg.norm(pos_i - pos_j_image) if 0 < distance < self.cutoff: edges.append([i, j]) edge_attrs.append([distance]) # 边特征:距离 return torch.tensor(edges).T, torch.tensor(edge_attrs)

二、图神经网络基础

1. 消息传递范式 (Message Passing)

这是所有GNN的统一框架:

# 核心思想: 节点通过聚合邻居信息来更新自己 def message_passing_layer(node_features, edge_index, edge_features): """ 1. Message: 每条边生成一个消息 2. Aggregate: 每个节点聚合收到的所有消息 3. Update: 用聚合后的消息更新节点特征 """ src, dst = edge_index # 源节点和目标节点 # 1. 生成消息 messages = message_function( node_features[src], # 源节点特征 node_features[dst], # 目标节点特征 edge_features # 边特征 ) # 2. 聚合消息 (sum/mean/max) aggregated = scatter_add(messages, dst, dim=0) # 3. 更新节点 new_features = update_function(node_features, aggregated) return new_features

2. 图卷积网络 (GCN)

import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import GCNConv class GCN(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim): super().__init__() self.conv1 = GCNConv(in_dim, hidden_dim) self.conv2 = GCNConv(hidden_dim, hidden_dim) self.conv3 = GCNConv(hidden_dim, out_dim) def forward(self, x, edge_index): # x: 节点特征 [num_nodes, in_dim] # edge_index: 边 [2, num_edges] x = F.relu(self.conv1(x, edge_index)) x = F.relu(self.conv2(x, edge_index)) x = self.conv3(x, edge_index) return x # GCN的数学形式 # H^(l+1) = σ(D^(-1/2) A D^(-1/2) H^(l) W^(l)) # A: 邻接矩阵 + 自环 # D: 度矩阵 # H: 节点特征矩阵 # W: 可学习权重

3. 图注意力网络 (GAT)

from torch_geometric.nn import GATConv class GAT(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, heads=4): super().__init__() self.conv1 = GATConv(in_dim, hidden_dim, heads=heads) self.conv2 = GATConv(hidden_dim * heads, out_dim, heads=1) def forward(self, x, edge_index): x = F.elu(self.conv1(x, edge_index)) x = self.conv2(x, edge_index) return x # GAT的核心: 学习邻居的重要性权重 # α_ij = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j])) # h'_i = σ(Σ_j α_ij W h_j)

4. 边特征的处理
晶体GNN需要处理边特征(如原子间距离):

from torch_geometric.nn import NNConv class EdgeConditionedConv(nn.Module): """ 边特征条件化的图卷积 消息函数依赖于边特征 """ def __init__(self, in_dim, out_dim, edge_dim): super().__init__() # 边特征 → 权重矩阵 self.edge_nn = nn.Sequential( nn.Linear(edge_dim, in_dim * out_dim), nn.ReLU() ) self.conv = NNConv(in_dim, out_dim, self.edge_nn) def forward(self, x, edge_index, edge_attr): return self.conv(x, edge_index, edge_attr)

三、几何深度学习:不变性与等变性

这是晶体GNN最重要的理论基础!

1. 为什么需要几何约束?

物理世界的对称性: 1. 平移不变性: 整体移动晶体,性质不变 2. 旋转不变性: 旋转晶体,性质不变 3. 排列不变性: 原子编号顺序不影响性质 好的晶体模型必须尊重这些对称性!

2. 如何实现不变性?

# 方法1: 使用不变特征 # 距离是旋转不变的! def invariant_edge_features(pos_i, pos_j): distance = torch.norm(pos_j - pos_i) # 标量,旋转不变 return distance # 方法2: 使用相对坐标 + 聚合 # 角度也是不变的 def get_angle(pos_i, pos_j, pos_k): vec_ij = pos_j - pos_i vec_ik = pos_k - pos_i cos_angle = (vec_ij @ vec_ik) / (norm(vec_ij) * norm(vec_ik)) return cos_angle

3. 等变神经网络

# 等变网络保证: 旋转输入 → 输出也相应旋转 class EquivariantLayer(nn.Module): """ 处理标量和向量的等变层 """ def __init__(self, scalar_dim, vector_dim): super().__init__() self.scalar_net = nn.Linear(scalar_dim, scalar_dim) self.vector_net = nn.Linear(vector_dim, vector_dim) self.mix = nn.Linear(scalar_dim + vector_dim, scalar_dim) def forward(self, scalars, vectors): # scalars: [N, scalar_dim] - 不变量 # vectors: [N, 3, vector_dim] - 等变量 # 标量更新(可以用向量的模) vector_norms = vectors.norm(dim=1) # [N, vector_dim] scalar_update = self.mix(torch.cat([scalars, vector_norms], dim=-1)) new_scalars = scalars + self.scalar_net(scalar_update) # 向量更新(保持方向,只改变幅度) # 不能用任意变换,否则破坏等变性 vector_scale = self.vector_net(vector_norms).unsqueeze(1) # [N, 1, vector_dim] new_vectors = vectors * vector_scale return new_scalars, new_vectors

四、球谐函数与角度信息

高级晶体GNN(如DimeNet、GemNet)使用球谐函数编码方向信息。

1. 为什么需要球谐函数?

问题: 如何表示3D方向,同时保持旋转等变性? 普通做法: 用(x, y, z)坐标 问题: 坐标依赖参考系,旋转后会变 球谐函数: 在球面上的"傅里叶基" 优点: 天然具有旋转等变性

2.球谐函数基础

import e3nn from e3nn import o3 # 球谐函数 Y_l^m(θ, φ) # l: 角动量量子数 (0, 1, 2, ...) # m: 磁量子数 (-l, ..., 0, ..., l) # l=0: 1个基函数 (标量,不变) # l=1: 3个基函数 (向量,像x,y,z) # l=2: 5个基函数 (二阶张量) def spherical_harmonics(directions, lmax=2): """ 将方向向量编码为球谐函数 """ # directions: [N, 3] 单位向量 # 输出: [N, (lmax+1)^2] 球谐系数 return o3.spherical_harmonics( l=list(range(lmax + 1)), x=directions, normalize=True ) # 示例 direction = torch.tensor([[1.0, 0.0, 0.0]]) # x方向 sh = spherical_harmonics(direction, lmax=2) # 输出: [1, 9] (1 + 3 + 5 = 9个系数)

3. 在GNN中使用球谐函数

class SphericalMessage(nn.Module): """ 使用球谐函数的消息传递 """ def __init__(self, hidden_dim, lmax=2): super().__init__() self.lmax = lmax self.sh_dim = (lmax + 1) ** 2 # 距离编码 self.distance_embedding = nn.Sequential( GaussianBasis(num_gaussians=50), nn.Linear(50, hidden_dim) ) # 球谐系数的权重 self.sh_weight = nn.Linear(hidden_dim, self.sh_dim * hidden_dim) def forward(self, x, pos, edge_index): src, dst = edge_index # 计算边向量 edge_vec = pos[dst] - pos[src] edge_dist = edge_vec.norm(dim=-1, keepdim=True) edge_dir = edge_vec / (edge_dist + 1e-8) # 球谐编码方向 sh = spherical_harmonics(edge_dir, self.lmax) # [E, sh_dim] # 距离编码 dist_emb = self.distance_embedding(edge_dist) # [E, hidden] # 生成消息 weights = self.sh_weight(dist_emb).view(-1, self.sh_dim, hidden_dim) messages = torch.einsum('es,esh->eh', sh, weights) # [E, hidden] # 聚合 return scatter_add(messages, dst, dim=0)

五、模型对比与选择

模型几何信息计算效率精度适用场景
CGCNN距离快速筛选
SchNet距离分子性质
MEGNet距离晶体性质
DimeNet距离+角度精确预测
ALIGNN距离+角度晶体性质
GemNet距离+角度+二面角很高高精度
M3GNet三体MD模拟
CHGNet三体+电荷很高MD模拟+电荷感知
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 9:42:17

国内石油需求峰值延后至2040年,对A股意味着什么?全方位整理油气板块周期股逻辑

国内石油需求峰值延后至2040年,对A股意味着什么? 标签:石油需求峰值|油气板块|天然气|A股能源股|周期股逻辑 一、一个被市场低估的重要变化 前两年,市场对能源板块的主线判断几乎高度一致: “2030年前后石油需求见顶,传统能源进入下行周期。” 但最近,中石油经济技…

作者头像 李华
网站建设 2026/4/17 20:32:37

Hyperf集合操作终极指南:数据处理新境界

还在为复杂的数组操作而烦恼吗&#xff1f;Hyperf集合组件将彻底改变你的数据处理方式&#xff01;作为PHP开发者的得力助手&#xff0c;它提供了超过100个实用的方法&#xff0c;让数组操作变得前所未有的简单和高效。 【免费下载链接】hyperf &#x1f680; A coroutine fram…

作者头像 李华
网站建设 2026/4/18 6:34:23

FlashAttention三大核心技术:如何让大模型推理速度提升5倍

FlashAttention三大核心技术&#xff1a;如何让大模型推理速度提升5倍 【免费下载链接】flash-attention Fast and memory-efficient exact attention 项目地址: https://gitcode.com/GitHub_Trending/fl/flash-attention 大语言模型推理过程中的性能瓶颈一直是困扰开发…

作者头像 李华
网站建设 2026/4/18 6:34:21

Java面试实战:从Spring Boot到微服务架构的全面解析

场景描述 在一家知名互联网大厂的会议室里&#xff0c;面试官李老师正在对一位名叫“超好吃”的Java小白求职者进行面试。此次面试主要涉及电商场景下的技术栈应用。 第一轮提问 李老师&#xff1a; 你能简要谈谈在电商网站中&#xff0c;我们为什么选择Spring Boot来构建后台服…

作者头像 李华