
Visualizing Transformer Attention Weights


张小明

Front-end development engineer


Visualization

import os
import json

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import List


def visualize_attention_distribution(
    attentions,
    input_ids,
    processor,
    gt_start_frame,
    gt_end_frame,
    query_text,
    video_id: str,
    save_dir: str = "/home/share/svmd5vm0/home/scut_czy1/attn_map",
    show_all_layers: bool = True,
    figsize: tuple = (20, 12),
):
    """
    Visualize how much attention the query tokens pay to each video frame.

    Args:
        attentions: attentions returned by the model, a tuple with one tensor per
            layer, each of shape (batch, num_heads, seq_len, seq_len)
        input_ids: input token ids
        processor: tokenizer/processor
        gt_start_frame: ground-truth start frame
        gt_end_frame: ground-truth end frame
        query_text: query text
        video_id: video ID, used in the output file names
        save_dir: output directory
        show_all_layers: whether to also plot per-layer attention
        figsize: figure size
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Get the special vision token ids
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_start|>')
    vision_end_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_end|>')

    # 2. Locate the query tokens in the input sequence
    input_ids_list = input_ids[0].tolist()
    query = query_text.strip()
    if query.endswith('.'):
        query = query[:-1]
    query_ids = processor.tokenizer(query, add_special_tokens=False)["input_ids"]
    query_start_idx = None
    query_end_idx = None
    for i in range(len(input_ids_list) - len(query_ids) + 1):
        if input_ids_list[i:i + len(query_ids)] == query_ids:
            query_start_idx = i
            query_end_idx = i + len(query_ids) - 1
            break
    if query_start_idx is None:
        print(f"Warning: Query tokens not found for video {video_id}")
        return

    # 3. Locate the vision tokens belonging to each frame
    vision_start_indices = [i for i, x in enumerate(input_ids_list) if x == vision_start_token_id]
    vision_end_indices = [i for i, x in enumerate(input_ids_list) if x == vision_end_token_id]
    num_frames = len(vision_start_indices)
    num_layers = len(attentions)
    if num_frames == 0:
        print(f"Warning: No vision tokens found for video {video_id}")
        return
    gt_end_frame = min(gt_end_frame, num_frames - 1)

    # 4. Extract the attention score for every (layer, frame) pair
    #    layer_frame_attention: [num_layers, num_frames]
    layer_frame_attention = []
    for layer_idx in range(num_layers):
        frame_scores = []
        layer_attn = attentions[layer_idx][0]  # [num_heads, seq_len, seq_len]
        for frame_idx in range(num_frames):
            v_start = vision_start_indices[frame_idx]
            v_end = vision_end_indices[frame_idx]
            # Attention from the query tokens to this frame's vision tokens
            query_to_frame_attn = layer_attn[:, query_start_idx:query_end_idx + 1, v_start + 1:v_end]
            # Average over heads, query tokens and vision patches
            frame_score = query_to_frame_attn.mean().item()
            frame_scores.append(frame_score)
        layer_frame_attention.append(frame_scores)
    layer_frame_attention = np.array(layer_frame_attention)  # [num_layers, num_frames]

    # 5. Average attention across all layers
    avg_attention = layer_frame_attention.mean(axis=0)  # [num_frames]

    # 6. Build the figure
    if show_all_layers and num_layers > 1:
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

        # ---------- Panel 1: attention heatmap over all layers ----------
        ax1 = fig.add_subplot(gs[0, :])
        im = ax1.imshow(layer_frame_attention, aspect='auto', cmap='YlOrRd', interpolation='nearest')
        ax1.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
        ax1.set_ylabel('Layer Index', fontsize=12, fontweight='bold')
        ax1.set_title(f'Attention Heatmap: Query → Frames (All Layers)\nQuery: "{query}"',
                      fontsize=14, fontweight='bold', pad=20)
        # Mark the ground-truth segment
        ax1.axvline(x=gt_start_frame - 0.5, color='blue', linestyle='--', linewidth=2, label='GT Start')
        ax1.axvline(x=gt_end_frame + 0.5, color='blue', linestyle='--', linewidth=2, label='GT End')
        cbar = plt.colorbar(im, ax=ax1)
        cbar.set_label('Attention Score', fontsize=10, fontweight='bold')
        ax1.legend(loc='upper right')

        # ---------- Panel 2: average attention per frame ----------
        ax2 = fig.add_subplot(gs[1, :])
        frames = np.arange(num_frames)
        colors = ['lightcoral' if gt_start_frame <= i <= gt_end_frame else 'lightblue'
                  for i in range(num_frames)]
        bars = ax2.bar(frames, avg_attention, color=colors, edgecolor='black', linewidth=0.5)
        # Highlight the ground-truth frames
        for i in range(gt_start_frame, gt_end_frame + 1):
            bars[i].set_edgecolor('red')
            bars[i].set_linewidth(2)
        ax2.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
        ax2.set_ylabel('Average Attention Score', fontsize=12, fontweight='bold')
        ax2.set_title('Average Attention Distribution (All Layers & Heads)',
                      fontsize=14, fontweight='bold', pad=15)
        ax2.grid(axis='y', alpha=0.3, linestyle='--')
        ax2.axvspan(gt_start_frame - 0.5, gt_end_frame + 0.5, alpha=0.2, color='red',
                    label=f'GT Frames [{gt_start_frame}, {gt_end_frame}]')
        ax2.legend(loc='upper right')

        # ---------- Panel 3: target vs non-target frames ----------
        ax3 = fig.add_subplot(gs[2, 0])
        target_attention = avg_attention[gt_start_frame:gt_end_frame + 1]
        non_target_mask = np.ones(num_frames, dtype=bool)
        non_target_mask[gt_start_frame:gt_end_frame + 1] = False
        non_target_attention = avg_attention[non_target_mask]
        comparison_data = [target_attention, non_target_attention]
        box = ax3.boxplot(comparison_data, labels=['Target Frames', 'Non-Target Frames'],
                          patch_artist=True, showmeans=True)
        box['boxes'][0].set_facecolor('lightcoral')
        box['boxes'][1].set_facecolor('lightblue')
        ax3.set_ylabel('Attention Score', fontsize=12, fontweight='bold')
        ax3.set_title('Target vs Non-Target Frames', fontsize=13, fontweight='bold', pad=15)
        ax3.grid(axis='y', alpha=0.3, linestyle='--')
        # Summary statistics shown inside the panel
        target_mean = target_attention.mean()
        non_target_mean = non_target_attention.mean()
        ratio = target_mean / (non_target_mean + 1e-7)
        stats_text = f'Target Mean: {target_mean:.4f}\n'
        stats_text += f'Non-Target Mean: {non_target_mean:.4f}\n'
        stats_text += f'Ratio: {ratio:.2f}x'
        ax3.text(0.02, 0.98, stats_text, transform=ax3.transAxes, fontsize=10,
                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # ---------- Panel 4: layer-wise attention trend ----------
        ax4 = fig.add_subplot(gs[2, 1])
        layer_target_mean = []
        layer_non_target_mean = []
        for layer_idx in range(num_layers):
            target_mean = layer_frame_attention[layer_idx, gt_start_frame:gt_end_frame + 1].mean()
            non_target_mean = layer_frame_attention[layer_idx, non_target_mask].mean()
            layer_target_mean.append(target_mean)
            layer_non_target_mean.append(non_target_mean)
        layers = np.arange(num_layers)
        ax4.plot(layers, layer_target_mean, 'o-', color='red', linewidth=2,
                 markersize=6, label='Target Frames')
        ax4.plot(layers, layer_non_target_mean, 's-', color='blue', linewidth=2,
                 markersize=6, label='Non-Target Frames')
        ax4.set_xlabel('Layer Index', fontsize=12, fontweight='bold')
        ax4.set_ylabel('Mean Attention Score', fontsize=12, fontweight='bold')
        ax4.set_title('Layer-wise Attention Trend', fontsize=13, fontweight='bold', pad=15)
        ax4.legend(loc='best')
        ax4.grid(alpha=0.3, linestyle='--')
    else:
        # Simplified version: only plot the layer-averaged attention
        fig, ax = plt.subplots(figsize=(12, 6))
        frames = np.arange(num_frames)
        colors = ['lightcoral' if gt_start_frame <= i <= gt_end_frame else 'lightblue'
                  for i in range(num_frames)]
        bars = ax.bar(frames, avg_attention, color=colors, edgecolor='black', linewidth=0.5)
        for i in range(gt_start_frame, gt_end_frame + 1):
            bars[i].set_edgecolor('red')
            bars[i].set_linewidth(2)
        ax.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
        ax.set_ylabel('Average Attention Score', fontsize=12, fontweight='bold')
        ax.set_title(f'Attention Distribution\nQuery: "{query}"', fontsize=14, fontweight='bold', pad=20)
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.axvspan(gt_start_frame - 0.5, gt_end_frame + 0.5, alpha=0.2, color='red',
                   label=f'GT Frames [{gt_start_frame}, {gt_end_frame}]')
        ax.legend(loc='upper right')

    # 7. Save the figure
    save_path = os.path.join(save_dir, f"{video_id}_attention_distribution.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved attention visualization to: {save_path}")
    plt.close()

    # 8. Save the raw numbers as an .npz archive
    save_data_path = os.path.join(save_dir, f"{video_id}_attention_data.npz")
    np.savez(
        save_data_path,
        layer_frame_attention=layer_frame_attention,
        avg_attention=avg_attention,
        gt_start_frame=gt_start_frame,
        gt_end_frame=gt_end_frame,
        query=query,
    )
    print(f"Saved attention data to: {save_data_path}")

    # 9. Return summary statistics
    target_attention = avg_attention[gt_start_frame:gt_end_frame + 1]
    non_target_mask = np.ones(num_frames, dtype=bool)
    non_target_mask[gt_start_frame:gt_end_frame + 1] = False
    non_target_attention = avg_attention[non_target_mask]
    stats = {
        'video_id': video_id,
        'query': query,
        'num_frames': num_frames,
        'num_layers': num_layers,
        'gt_range': (gt_start_frame, gt_end_frame),
        'target_attention_mean': float(target_attention.mean()),
        'target_attention_std': float(target_attention.std()),
        'non_target_attention_mean': float(non_target_attention.mean()),
        'non_target_attention_std': float(non_target_attention.std()),
        'attention_ratio': float(target_attention.mean() / (non_target_attention.mean() + 1e-7)),
        'attention_concentration': float(target_attention.sum() / avg_attention.sum()),
    }
    return stats


def batch_visualize_attention(
    model,
    processor,
    data_list: List[dict],
    save_dir: str = "/home/share/svmd5vm0/home/scut_czy1/attn_map",
    device: str = "cuda",
):
    """
    Run the attention visualization for a list of videos.

    Args:
        model: the model
        processor: the processor
        data_list: list of dicts, each containing:
            - video_path: path to the video
            - query: query text
            - start_frame: ground-truth start frame
            - end_frame: ground-truth end frame
            - video_id: video ID
        save_dir: output directory
        device: device to run on
    """
    model.eval()
    all_stats = []
    for data in data_list:
        print(f"\nProcessing video: {data['video_id']}")
        # Build the chat-style input
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": data['video_path'], "fps": 1},
                    {"type": "text", "text": data['query']},
                ],
            }
        ]
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)
        # Forward pass, requesting the attention weights
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        # Visualize
        stats = visualize_attention_distribution(
            attentions=outputs.attentions,
            input_ids=inputs['input_ids'],
            processor=processor,
            gt_start_frame=data['start_frame'],
            gt_end_frame=data['end_frame'],
            query_text=data['query'],
            video_id=data['video_id'],
            save_dir=save_dir,
        )
        all_stats.append(stats)

    # Save all statistics
    stats_path = os.path.join(save_dir, "all_stats.json")
    with open(stats_path, 'w') as f:
        json.dump(all_stats, f, indent=4)
    print(f"\nSaved all statistics to: {stats_path}")
    return all_stats


# ========== Usage examples ==========
if __name__ == "__main__":
    # Example 1: visualize a single video
    from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        "/home/share/svmd5vm0/home/scut_czy1/Qwen3-VL-2B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="eager",
    )
    processor = AutoProcessor.from_pretrained("/home/share/svmd5vm0/home/scut_czy1/Qwen3-VL-2B-Instruct")

    query_text = "A person is reading a book"
    messages = [{
        "role": "user",
        "content": [
            {"type": "video",
             "video": "/home/share/svmd5vm0/home/scut_czy1/datasets/Charadesfps/videos_1FPS/0A8CF.mp4",
             "fps": 1},
            {"type": "text", "text": query_text},
        ],
    }]
    inputs = processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
    ).to("cuda")

    # Forward pass, requesting the attention weights
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    stats = visualize_attention_distribution(
        attentions=outputs.attentions,
        input_ids=inputs['input_ids'],
        processor=processor,
        gt_start_frame=5,
        gt_end_frame=9,
        query_text=query_text,
        video_id="video_001",
        save_dir="/home/share/svmd5vm0/home/scut_czy1/attn_map",
    )
    print("Statistics:", stats)

    # Example 2: batch processing
    """
    data_list = [
        {'video_path': 'video1.mp4', 'query': 'person drinking water',
         'start_frame': 5, 'end_frame': 9, 'video_id': 'video_001'},
        {'video_path': 'video2.mp4', 'query': 'person opening door',
         'start_frame': 10, 'end_frame': 15, 'video_id': 'video_002'},
    ]
    all_stats = batch_visualize_attention(
        model=model,
        processor=processor,
        data_list=data_list,
        save_dir="./visualizations",
    )
    """
    print("Attention visualization utilities are ready!")
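
Two practical notes. First, per-layer attention weights are only returned when output_attentions=True is supported by the attention backend, which is why the example loads the model with attn_implementation="eager"; fused kernels such as FlashAttention generally do not expose the weights. Second, because the script dumps everything to an .npz archive, a plot can be rebuilt later without re-running the model. The snippet below is a minimal sketch of that reload step; the file path assumes the default save_dir and the video_001 ID from the single-video example above and is purely illustrative.

# Minimal sketch (illustrative path): reload the saved attention data and
# re-plot the frame-level average without re-running the model.
import numpy as np
import matplotlib.pyplot as plt

data = np.load("/home/share/svmd5vm0/home/scut_czy1/attn_map/video_001_attention_data.npz")
layer_frame_attention = data["layer_frame_attention"]  # [num_layers, num_frames]
avg_attention = data["avg_attention"]                  # [num_frames]
gt_start = int(data["gt_start_frame"])
gt_end = int(data["gt_end_frame"])

plt.figure(figsize=(12, 4))
plt.bar(np.arange(len(avg_attention)), avg_attention, color="lightblue", edgecolor="black")
plt.axvspan(gt_start - 0.5, gt_end + 0.5, alpha=0.2, color="red", label="GT Frames")
plt.xlabel("Frame Index")
plt.ylabel("Average Attention Score")
plt.legend()
plt.savefig("reloaded_attention.png", dpi=150, bbox_inches="tight")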