news 2026/6/10 14:57:37

Vlm-Transformer_demo

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Vlm-Transformer_demo
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import random # ===================== 1. 准备数据(字符级语料) ===================== # 简单语料(自己构造,无需下载) #训练样本数: 89 | 词汇表字符: [' ', 'a', 'c', 'd', 'e', 'f', 'h', 'i', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'w', 'y'] corpus = [ "hello transformer", "transformer is a powerful model", "pytorch transformer demo", "transformer uses self attention", "attention is the core of transformer" ] # 收集所有唯一字符,建立字符→索引、索引→字符映射 all_chars = sorted(list(set("".join(corpus)))) # 去重+排序 char2idx = {char: idx for idx, char in enumerate(all_chars)} idx2char = {idx: char for char, idx in char2idx.items()} vocab_size = len(all_chars) # 词汇表大小(字符数) seq_len = 10 # 输入序列长度(取前10个字符,预测第11个) # 生成训练数据:输入是前seq_len个字符,目标是第seq_len+1个字符 def generate_data(corpus, seq_len): inputs = [] targets = [] for sentence in corpus: # 将句子转成字符索引序列 sentence_idx = [char2idx[c] for c in sentence] # 滑动窗口生成样本(确保长度足够) for i in range(len(sentence_idx) - seq_len): input_seq = sentence_idx[i:i+seq_len] target_char = sentence_idx[i+seq_len] inputs.append(input_seq) targets.append(target_char) # 转成Tensor return torch.tensor(inputs), torch.tensor(targets) # 生成训练集 train_inputs, train_targets = generate_data(corpus, seq_len) print(f"训练样本数: {len(train_inputs)} | 词汇表字符: {all_chars}") # ===================== 2. 定义位置编码(Transformer必需) ===================== class PositionalEncoding(nn.Module): def __init__(self, embedding_dim, max_len=5000): super().__init__() # 预计算位置编码 position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-torch.log(torch.tensor(10000.0)) / embedding_dim)) pe = torch.zeros(max_len, 1, embedding_dim) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) # 不参与训练的参数 def forward(self, x): # x形状: [seq_len, batch_size, embedding_dim] x = x + self.pe[:x.size(0)] return x # ===================== 3. 定义Transformer模型 ===================== class TransformerLM(nn.Module): def __init__(self, vocab_size, embedding_dim=16, nhead=2, num_layers=2): super().__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) # 字符嵌入 self.pos_encoder = PositionalEncoding(embedding_dim) # 位置编码 # Transformer编码器(这里用编码器做语言模型,也可以用解码器) encoder_layers = nn.TransformerEncoderLayer( d_model=embedding_dim, # 输入维度(和嵌入维度一致) nhead=nhead, # 多头注意力的头数 dim_feedforward=64 # 前馈网络的隐藏层维度 ) self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers) self.fc = nn.Linear(embedding_dim, vocab_size) # 输出到词汇表 def forward(self, x): # x形状: [batch_size, seq_len] → 转成Transformer要求的[seq_len, batch_size] x = x.transpose(0, 1) # 嵌入+位置编码: [seq_len, batch_size, embedding_dim] x = self.embedding(x) x = self.pos_encoder(x) # Transformer编码: [seq_len, batch_size, embedding_dim] x = self.transformer_encoder(x) # 取最后一个时间步的输出(预测下一个字符): [batch_size, embedding_dim] x = x[-1, :, :] # 输出分类: [batch_size, vocab_size] x = self.fc(x) return x # ===================== 4. 训练模型 ===================== def train(): # 超参数 embedding_dim = 16 #字符嵌入维度(每个字符转成 16 维向量) nhead = 2 #多头 num_layers = 2 #编码器层数 batch_size = 4 #每次训练的样本数 epochs = 50 #训练轮数 lr = 0.001 #学习率(参数更新的步长) # 初始化模型、损失函数、优化器 #实例化我们定义的 Transformer 字符模型,把超参数传入(比如 embedding_dim=16) #此时模型的参数(QKV 权重、嵌入层权重等)都是随机初始化的,还没学到任何东西。 model = TransformerLM(vocab_size, embedding_dim, nhead, num_layers) #交叉熵 #损失函数 #选择交叉熵作为损失函数 #交叉熵损失是分类任务的 “标配” criterion = nn.CrossEntropyLoss() # 分类损失(预测字符) #Adam是SGD的升级版,自带自适应学习 #优化器要更新的是模型的所有可训练参数(QKV、嵌入层、全连接层等) optimizer = optim.Adam(model.parameters(), lr=lr) # 训练循环 #切换训练模式 model.train() for epoch in range(epochs): total_loss = 0.0 #初始化 “本轮 Epoch 的总损失”,用于统计整个 Epoch 的平均损失(损失越小→模型预测越准) # 随机打乱数据(按batch处理) indices = torch.randperm(len(train_inputs)) for i in range(0, len(train_inputs), batch_size):#取一个步长=batch 例如4 88个样本 22 # 取一个batch batch_idx = indices[i:i+batch_size]#取当前批次的 例如 4 那就是4-7 batch_inputs = train_inputs[batch_idx]#取当前批次的输入序列 batch_targets = train_targets[batch_idx]#取当前批次的目标字符 # 前向传播 outputs = model(batch_inputs)#输出 “预测结果” loss = criterion(outputs, batch_targets)#用交叉熵损失函数,计算 “模型预测得分” 和 “真实目标字符” 的差距 # 反向传播+优化 optimizer.zero_grad()#清空模型所有参数的梯度 loss.backward()#反向传播计算梯度 optimizer.step()#用梯度更新参数 total_loss += loss.item() # 每5轮打印一次损失 if (epoch + 1) % 5 == 0: print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_inputs):.4f}") print("训练完成!") return model # ===================== 5. 测试生成(输入前缀,生成后续字符) ===================== def generate_text(model, prefix, max_len=20): model.eval() # 将前缀转成字符索引 input_seq = [char2idx[c] for c in prefix] with torch.no_grad(): for _ in range(max_len): # 取最后seq_len个字符作为输入(不足则补0) current_input = torch.tensor([input_seq[-seq_len:] if len(input_seq)>=seq_len else [0]*(seq_len-len(input_seq)) + input_seq]) # 预测下一个字符 output = model(current_input) next_char_idx = output.argmax(dim=1).item() input_seq.append(next_char_idx) # 如果生成空格或结束符(这里用空格当结束),提前停止 if idx2char[next_char_idx] == " ": break # 转成字符 return "".join([idx2char[idx] for idx in input_seq]) # ===================== 主函数(直接运行) ===================== if __name__ == "__main__": # 训练模型 trained_model = train() # 测试生成(输入不同前缀) prefixes = ["trans", "att", "pyt"] for prefix in prefixes: generated = generate_text(trained_model, prefix) print(f"\n输入前缀: '{prefix}' → 生成结果: '{generated}'")
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/10 9:56:47

深度解析|当 Prometheus 遇见大模型:解密下一代智能监控体系

导读在云原生时代,Prometheus Alertmanager 虽然解决了“看得见”的问题,却无法解决“看得懂”和“看得早”的难题。运维团队往往陷入“故障发生->收到告警->紧急救火”的被动循环。 本文将探讨如何利用 AI 大模型技术赋能现有监控体系&#xff0…

作者头像 李华
网站建设 2026/6/10 9:53:45

L298N典型应用电路搭建手把手教程

手把手教你用L298N驱动直流电机:从零搭建稳定控制电路你有没有遇到过这样的情况?写好了Arduino程序,信心满满地给小车通电,结果电机纹丝不动——或者只转一个方向,还“嗡嗡”发热。别急,问题很可能出在电机…

作者头像 李华
网站建设 2026/6/9 22:01:54

Java Web 车辆管理系统系统源码-SpringBoot2+Vue3+MyBatis-Plus+MySQL8.0【含文档】

摘要 随着城市化进程的加快和私家车保有量的持续增长,车辆管理成为城市治理的重要课题。传统车辆管理方式依赖人工登记和纸质档案,存在效率低下、数据易丢失、查询困难等问题。信息化技术的普及为车辆管理提供了新的解决方案,通过构建智能化的…

作者头像 李华
网站建设 2026/6/10 9:50:23

IT自动分派单据如何实现?从规则到智能分派全解读

在IT运维现场,工单处理是否高效往往已经由“分派”确定。在系统上线初期很多企业还能依靠人工判断而随着系统数量、用户规模不断扩大即将由人工派单逐步成为瓶颈。正因为如此,IT自动分派单据已开始被越来越多IT团队视为基本能力兼运维流程中的关键一环&a…

作者头像 李华
网站建设 2026/6/9 18:30:55

nmodbus4类库在PLC通信中的应用完整指南

用 nmodbus4 打通工业通信——从零构建稳定可靠的 PLC 数据交互系统在现代工厂的控制室里,一台运行着 C# 编写的监控软件的工控机,正通过网线与远处的西门子 S7-1200 PLC 进行高速数据交换。温度、压力、电机状态实时刷新,一旦超过阈值&#…

作者头像 李华
网站建设 2026/6/10 13:19:49

零基础掌握HardFault异常处理机制的基本原理

破解HardFault之谜:从崩溃现场还原程序“死亡瞬间”你有没有遇到过这样的场景?代码烧进去,设备上电后一切正常,突然毫无征兆地卡死——没有日志、无法复现、JTAG一连才发现:程序停在了while(1)里,而调用栈清…

作者头像 李华