news 2026/5/9 18:14:43

CANN旋转位置编码算子API

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CANN旋转位置编码算子API

ApplyRotaryPosEmb 算子 API 描述

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

1. 算子简介

对 query 和 key 执行旋转位置编码 (RoPE) 计算。

主要应用场景

  • 大语言模型(LLaMA、GPT-NeoX 等)中的位置编码
  • Transformer 自注意力机制中 query 和 key 的位置信息注入
  • 支持长序列外推的相对位置编码方案

算子特征

  • 难度等级:L2(FusedComposite)
  • 四输入(query、key、cos、sin)双输出(query_out、key_out)
  • 支持 BSND 和 BNSD 两种布局,以及 half 和 interleaved 两种旋转模式

2. 算子定义

数学公式

$$ rotate_half(x) = concat(-x[head_dim/2:], x[:head_dim/2]) $$

$$ y = x \cdot cos + rotate_half(x) \cdot sin $$

其中:

  • 在 half(连续半分)模式下,将最后一维分为前后两半进行旋转
  • 在 interleaved(交错)模式下,取偶数/奇数索引位置交错旋转
  • cos 和 sin 为预计算的位置编码,需要广播到与 query/key 匹配的 shape

3. 接口规范

算子原型

cann_bench.apply_rotary_pos_emb(Tensor query, Tensor key, Tensor cos, Tensor sin, int layout, str rotaryMode) -> (Tensor query_out, Tensor key_out)

输入参数说明

参数类型默认值描述
queryTensor必选查询张量
keyTensor必选键张量
cosTensor必选余弦位置编码
sinTensor必选正弦位置编码
layoutint0输入布局 (0: [B,S,N,D], 1: [B,N,S,D])
rotaryModestring"half"旋转模式 ("half": 连续半分式,"interleaved": 交错式)

输出

参数Shapedtype描述
query_out与输入 query 相同与输入 query 相同旋转后的查询张量
key_out与输入 key 相同与输入 key 相同旋转后的键张量

数据类型

输入 dtype输出 dtype
float32float32
float16float16
bfloat16bfloat16

规则与约束

  • query 和 key 的 shape 必须相同
  • query/key 为 4D 张量:layout=0 时为 (batch_size, seq_len, num_heads, head_dim),layout=1 时为 (batch_size, num_heads, seq_len, head_dim)
  • cos/sin 为 (seq_len, head_dim/2) 或 (batch_size, seq_len, head_dim/2)
  • head_dim 必须为偶数(需要分为两半进行旋转)
  • 所有输入张量的 dtype 必须一致

4. 精度要求

采用生态算子精度标准进行验证。

误差指标

  1. 平均相对误差(MERE):采样点中相对误差平均值

    $$ \text{MERE} = \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

  2. 最大相对误差(MARE):采样点中相对误差最大值

    $$ \text{MARE} = \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

通过标准

数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2
通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2

当平均相对误差 MERE < Threshold,最大相对误差 MARE < 10 * Threshold 时判定为通过。

5. 标准 Golden 代码

import torch """ ApplyRotaryPosEmb 算子 Torch Golden 参考实现 对 query 和 key 执行旋转位置编码 (RoPE) 计算 公式: rotate_half(x) = concat(-x[head_dim/2:], x[:head_dim/2]) y = (x * cos) + (rotate_half(x) * sin) 参考: - RoFormer: https://arxiv.org/abs/2104.09864 - LLaMA: https://github.com/meta-llama/llama - HuggingFace transformers: https://huggingface.co/docs/transformers/internal/rope_utils """ def apply_rotary_pos_emb( query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, layout: int = 0, rotaryMode: str = 'half' ) -> tuple[torch.Tensor, torch.Tensor]: """ 对 query 和 key 执行旋转位置编码 (RoPE) 计算 Args: query: 查询张量,shape 为 (B, S, N, D) 或 (B, N, S, D) key: 键张量,shape 同 query cos: 余弦位置编码,shape 为 (S, D/2) 或 (B, S, D/2) sin: 正弦位置编码,shape 同 cos layout: 输入布局 (0: [B,S,N,D], 1: [B,N,S,D]) rotaryMode: 旋转模式 ("half": 连续半分式,"interleaved": 交错式) Returns: query_out: 旋转后的查询张量 key_out: 旋转后的键张量 Examples: >>> B, S, N, D = 2, 4, 8, 128 >>> query = torch.randn(B, S, N, D) >>> key = torch.randn(B, S, N, D) >>> cos = torch.randn(S, D // 2) >>> sin = torch.randn(S, D // 2) >>> q_out, k_out = apply_rotary_pos_emb(query, key, cos, sin) """ def rotate_half(x: torch.Tensor, mode: str) -> torch.Tensor: """ 旋转输入张量的一半维度 Args: x: 输入张量 mode: 旋转模式 Returns: 旋转后的张量 """ if mode == 'interleaved': # GPT-J 风格的交错式旋转 x1 = x[..., ::2] # 取偶数索引 x2 = x[..., 1::2] # 取奇数索引 rotated = torch.stack([-x2, x1], dim=-1).flatten(-2) else: # LLaMA/Meta 风格的连续半分式旋转 half_dim = x.shape[-1] // 2 x1 = x[..., :half_dim] x2 = x[..., half_dim:] rotated = torch.cat([-x2, x1], dim=-1) return rotated def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mode: str) -> torch.Tensor: """ 对单个张量应用 RoPE Args: x: 输入张量 cos: 余弦编码 sin: 正弦编码 mode: 旋转模式 Returns: 旋转后的张量 """ # 调整 cos/sin 的 shape 以匹配输入 # cos/sin: (S, D/2) 或 (B, S, D/2) # 需要扩展到 (B, S, N, D) 或 (B, N, S, D) if cos.dim() == 2: # cos: (S, D/2) -> 需要扩展到 (B, S, 1, D) cos = cos.unsqueeze(0).unsqueeze(2) # (1, S, 1, D/2) sin = sin.unsqueeze(0).unsqueeze(2) elif cos.dim() == 3: # cos: (B, S, D/2) -> 需要扩展到 (B, S, 1, D) cos = cos.unsqueeze(2) # (B, S, 1, D/2) sin = sin.unsqueeze(2) # 如果 layout=1 (B,N,S,D),需要调整 if layout == 1: cos = cos.transpose(1, 2) # (B, 1, S, D/2) sin = sin.transpose(1, 2) # 重复 cos/sin 到完整的 head_dim cos = cos.repeat(1, 1, 1, 2) if cos.dim() == 4 else cos.repeat_interleave(2, dim=-1) sin = sin.repeat(1, 1, 1, 2) if sin.dim() == 4 else sin.repeat_interleave(2, dim=-1) # 应用 RoPE 公式 x_rotate = rotate_half(x, mode) return (x * cos) + (x_rotate * sin) # 对 query 和 key 分别应用 RoPE query_out = apply_rotary(query, cos, sin, rotaryMode) key_out = apply_rotary(key, cos, sin, rotaryMode) return query_out, key_out

6. 额外信息

算子调用示例

import torch import cann_bench B, S, N, D = 2, 128, 32, 128 query = torch.randn(B, S, N, D, dtype=torch.float16, device="npu") key = torch.randn(B, S, N, D, dtype=torch.float16, device="npu") cos = torch.randn(S, D // 2, dtype=torch.float16, device="npu") sin = torch.randn(S, D // 2, dtype=torch.float16, device="npu") # BSND 布局,half 模式 q_out, k_out = cann_bench.apply_rotary_pos_emb(query, key, cos, sin, layout=0, rotaryMode="half") # BNSD 布局,interleaved 模式 query_bnsd = query.transpose(1, 2) key_bnsd = key.transpose(1, 2) q_out, k_out = cann_bench.apply_rotary_pos_emb(query_bnsd, key_bnsd, cos, sin, layout=1, rotaryMode="interleaved")

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 18:13:43

AI加速MOF碳捕集筛选:GHP-MOFassemble框架性能解析与实战

1. 项目概述&#xff1a;当AI遇见MOF&#xff0c;一场关于碳捕集的“算力革命”最近几年&#xff0c;但凡关注材料科学和气候变化领域的朋友&#xff0c;肯定对“金属有机框架”&#xff08;MOF&#xff09;这个词不陌生。这东西就像一个由金属离子和有机配体搭起来的、结构无限…

作者头像 李华
网站建设 2026/5/9 18:06:41

AI项目管理中的算法偏见与包容性设计:效率与公平的平衡之道

1. 项目概述&#xff1a;当AI遇见项目管理&#xff0c;效率与公平的十字路口干了十几年项目管理&#xff0c;从最初拿着甘特图跟团队死磕&#xff0c;到后来用上各种协作软件&#xff0c;我原以为工具迭代的终点就是流程自动化。直到人工智能&#xff08;AI&#xff09;开始渗透…

作者头像 李华
网站建设 2026/5/9 18:06:40

AnimateDiff高级控制:通过草图引导视频生成

AnimateDiff高级控制&#xff1a;通过草图引导视频生成 1. 引言 你是否曾经遇到过这样的情况&#xff1a;用文字描述想要生成的视频内容&#xff0c;但AI生成的视频总是与你的想象有些差距&#xff1f;或者你想要精确控制视频中物体的运动轨迹和构图&#xff0c;却发现文字描…

作者头像 李华
网站建设 2026/5/9 18:00:56

华为CANN/tensorflow alltoallvc集合通信

alltoallvc 【免费下载链接】tensorflow Ascend TensorFlow Adapter 项目地址: https://gitcode.com/cann/tensorflow 功能说明 集合通信alltoallvc操作接口。向通信域内所有rank发送数据&#xff08;数据量可以定制&#xff09;&#xff0c;并从所有rank接收数据。 a…

作者头像 李华
网站建设 2026/5/9 17:59:19

图神经网络与强化学习融合:复杂网络智能决策实战指南

1. 项目概述&#xff1a;当复杂网络遇见智能决策最近几年&#xff0c;我身边搞生态建模、生物信息分析&#xff0c;甚至做城市规划的朋友&#xff0c;都开始频繁地跟我聊起两个词&#xff1a;图神经网络和强化学习。这让我意识到&#xff0c;一个非常有意思的技术融合正在发生。…

作者头像 李华