news 2026/4/18 8:53:54

图像自回归生成(Auto-regressive image generation)实战学习(二)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
图像自回归生成(Auto-regressive image generation)实战学习(二)

相关项目下载链接

训练框架

在开始实现相应模块功能之前,首先熟悉训练框架·train.py

1. 导入与模型字典构建

importinspectimportmathfromdatetimeimportdatetimefrompathlibimportPathimporttorchimportae,autoregressive,bsq# 自定义模型模块(AE/BSQ/自回归)# 收集ae/bsq模块中所有继承nn.Module的块级模型类patch_models={n:mforMin[ae,bsq]forn,mininspect.getmembers(M)ifinspect.isclass(m)andissubclass(m,torch.nn.Module)}# 收集autoregressive模块中所有继承nn.Module的自回归模型类ar_models={n:mforMin[autoregressive]forn,mininspect.getmembers(M)ifinspect.isclass(m)andissubclass(m,torch.nn.Module)}

2. 核心训练函数 train()

共包含三个部分:块级模型训练器 PatchTrainer、自回归模型训练器 AutoregressiveTrainer、模型保存回调 CheckPointer。
其中,

  1. 块级模型训练器 PatchTrainer专用于 AE/BSQ 模型。值得注意的是数据预处理过程,图像归一化的方式是(/255.0 - 0.5),将像素值映射到[-0.5, 0.5]而不是[0,1];损失函数采用MSE(均方误差),适配图像重构任务;优化器为AdamW,学习率1e-3;基于ImageDataset加载原始图像数据集。
  2. 自回归模型训练器 AutoregressiveTrainer专用于 AR 模型。使用交叉熵损失,适配令牌序列的分类预测任务;基于TokenDataset加载令牌化后的图像序列;优化器为AdamW,学习率1e-3。
  3. 模型保存回调 CheckPointer。模型保存的触发时机是在每个训练 epoch 结束后;有两种保存方式:一种是带时间戳的模型,保存方式为checkpoints/{时间戳}_{模型名}.pth;另一种是最新的模型,保存路径为当前目录下的{模型名}.pth

此外,还实现了模型加载 / 创建的逻辑。

deftrain(model_name_or_path:str,epochs:int=5,batch_size:int=64):importlightningasLfromlightning.pytorch.loggersimportTensorBoardLoggerfromdataimportImageDataset,TokenDatasetclassPatchTrainer(L.LightningModule):def__init__(self,model):super().__init__()self.model=modeldeftraining_step(self,x,batch_idx):x=x.float()/255.0-0.5x_hat,additional_losses=self.model(x)loss=torch.nn.functional.mse_loss(x_hat,x)self.log("train/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"train/{k}",v)returnloss+sum(additional_losses.values())defvalidation_step(self,x,batch_idx):x=x.float()/255.0-0.5withtorch.no_grad():x_hat,additional_losses=self.model(x)loss=torch.nn.functional.mse_loss(x_hat,x)self.log("validation/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"validation/{k}",v)ifbatch_idx==0:self.logger.experiment.add_images("input",(x[:64]+0.5).clamp(min=0,max=1).permute(0,3,1,2),self.global_step)self.logger.experiment.add_images("prediction",(x_hat[:64]+0.5).clamp(min=0,max=1).permute(0,3,1,2),self.global_step)returnlossdefconfigure_optimizers(self):returntorch.optim.AdamW(self.parameters(),lr=1e-3)deftrain_dataloader(self):dataset=ImageDataset("train")returntorch.utils.data.DataLoader(dataset,batch_size=batch_size,num_workers=4,shuffle=True)defval_dataloader(self):dataset=ImageDataset("valid")returntorch.utils.data.DataLoader(dataset,batch_size=4096,num_workers=4,shuffle=True)classAutoregressiveTrainer(L.LightningModule):def__init__(self,model):super().__init__()self.model=modeldeftraining_step(self,x,batch_idx):x_hat,additional_losses=self.model(x)loss=(torch.nn.functional.cross_entropy(x_hat.view(-1,x_hat.shape[-1]),x.view(-1),reduction="sum")/math.log(2)/x.shape[0])self.log("train/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"train/{k}",v)returnloss+sum(additional_losses.values())defvalidation_step(self,x,batch_idx):withtorch.no_grad():x_hat,additional_losses=self.model(x)loss=(torch.nn.functional.cross_entropy(x_hat.view(-1,x_hat.shape[-1]),x.view(-1),reduction="sum")/math.log(2)/x.shape[0])self.log("validation/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"validation/{k}",v)returnlossdefconfigure_optimizers(self):returntorch.optim.AdamW(self.parameters(),lr=1e-3)deftrain_dataloader(self):dataset=TokenDataset("train")returntorch.utils.data.DataLoader(dataset,batch_size=batch_size,num_workers=4,shuffle=True)defval_dataloader(self):dataset=TokenDataset("valid")returntorch.utils.data.DataLoader(dataset,batch_size=batch_size,num_workers=4,shuffle=True)classCheckPointer(L.Callback):defon_train_epoch_end(self,trainer,pl_module):fn=Path(f"checkpoints/{timestamp}_{model_name}.pth")fn.parent.mkdir(exist_ok=True,parents=True)torch.save(model,fn)torch.save(model,Path(__file__).parent/f"{model_name}.pth")# Load or create the modelifPath(model_name_or_path).exists():model=torch.load(model_name_or_path,weights_only=False)model_name=model.__class__.__name__else:model_name=model_name_or_pathifmodel_nameinpatch_models:model=patch_models[model_name]()elifmodel_nameinar_models:model=ar_models[model_name]()else:raiseValueError(f"Unknown model:{model_name}")# Create the lightning modelifisinstance(model,(autoregressive.Autoregressive)):l_model=AutoregressiveTrainer(model)else:l_model=PatchTrainer(model)timestamp=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")logger=TensorBoardLogger("logs",name=f"{timestamp}_{model_name}")trainer=L.Trainer(max_epochs=epochs,logger=logger,callbacks=[CheckPointer()])trainer.fit(model=l_model,)

3. 命令行启动

本项目借助fire库实现命令行参数解析,无需手动解析--epochs/--batch_size等参数,直接通过python train.py {模型名} --epochs 10启动训练。

if__name__=="__main__":fromfireimportFire Fire(train)

train.py核心使用方法如下:

# 训练块级自编码器python train.py PatchAutoEncoder--epochs5--batch_size64# 训练自回归模型python train.py AutoregressiveModel--epochs10--batch_size32# 加载已有模型续训python train.py checkpoints/2025-10-20_PatchAutoEncoder.pth--epochs10

加载数据

接下来熟悉这个项目是如何进行数据加载的,data.py模块定义两类 PyTorch 兼容的数据集类。
其中:

  • ImageDataset:加载原始 JPG 图像,提供缓存机制提升读取效率;
  • TokenDataset:加载令牌化后的图像张量(由tokenize.py生成),供自回归模型训练使用。。

1. 导入依赖库

frompathlibimportPathimporttorchfromPILimportImage# 自动定位数据集根目录:当前文件的父父目录下的data文件夹DATASET_PATH=Path(__file__).parent.parent/"data"

2. ImageDataset(原始图像数据集)

classImageDataset:def__init__(self,split:str,cache_images:bool=True):# 收集split(train/valid)目录下所有.jpg文件路径self.image_paths=list((DATASET_PATH/split).rglob("*.jpg"))# 初始化图像缓存列表,避免重复读取磁盘self._image_cache=[None]*len(self.image_paths)self._cache_images=cache_images# 是否开启缓存def__len__(self)->int:returnlen(self.image_paths)# 数据集总长度def__getitem__(self,idx:int)->torch.Tensor:# 优先读取缓存,无缓存则加载图像ifself._image_cache[idx]isnotNone:returnself._image_cache[idx]# 图像加载:PIL打开→转numpy数组→转torch.uint8张量(保持原始像素值)img=torch.tensor(np.array(Image.open(self.image_paths[idx])),dtype=torch.uint8)# 开启缓存则存入,后续复用ifself._cache_images:self._image_cache[idx]=imgreturnimg

3. TokenDataset(令牌化数据集)

classTokenDataset(torch.utils.data.TensorDataset):def__init__(self,split:str):# 加载令牌化后的张量文件(由tokenize.py生成)tensor_path=DATASET_PATH/f"tokenized_{split}.pth"ifnottensor_path.exists():# 文件不存在时给出明确提示,符合作业流程指引raiseFileNotFoundError(f"Tokenized dataset not found at{tensor_path}...")self.data=torch.load(tensor_path,weights_only=False)def__getitem__(self,idx:int)->torch.Tensor:# 返回长整型张量(适配自回归模型的离散令牌输入)returntorch.tensor(self.data[idx],dtype=torch.long)def__len__(self)->int:returnlen(self.data)

这两个数据集加载对象的使用方法如下所示:

# 加载训练集原始图像(用于AE/BSQ训练)fromdataimportImageDataset,TokenDataset train_img_ds=ImageDataset("train",cache_images=True)img_tensor=train_img_ds[0]# 取第0张图像,shape: (H, W, 3)# 加载训练集令牌数据(用于自回归模型训练)train_token_ds=TokenDataset("train")token_tensor=train_token_ds[0]# 取第0个令牌序列,shape: (序列长度,)# 配合DataLoader使用fromtorch.utils.dataimportDataLoader train_loader=DataLoader(train_token_ds,batch_size=64,shuffle=True)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 6:30:12

AcFun视频下载终极指南:2025年最全离线保存解决方案

还在为无法离线观看A站精彩视频而烦恼吗?今天为大家带来一款功能强大的免费工具——AcFunDown,让你轻松实现视频批量下载、多格式支持、断点续传等实用功能,彻底告别网络限制!这款专为AcFun用户设计的下载工具采用直观的图形操作界…

作者头像 李华
网站建设 2026/4/18 3:46:16

5步快速掌握机器人仿真:从零搭建Go2四足机器人的终极指南

5步快速掌握机器人仿真:从零搭建Go2四足机器人的终极指南 【免费下载链接】go2_ros2_sdk Unofficial ROS2 SDK support for Unitree GO2 AIR/PRO/EDU 项目地址: https://gitcode.com/gh_mirrors/go/go2_ros2_sdk 想要在虚拟世界中安全地测试机器人算法吗&…

作者头像 李华
网站建设 2026/4/18 7:57:18

Foobar2000 逐字歌词终极配置:让每句歌词都精准同步

还在为传统歌词的粗糙同步而烦恼吗?想象一下,当你聆听心爱的歌曲时,每个字词都如同跳动在屏幕上的音符,与旋律完美契合——这就是 ESLyric-LyricsSource 为 Foobar2000 用户带来的沉浸式歌词体验。 【免费下载链接】ESLyric-Lyric…

作者头像 李华
网站建设 2026/4/18 5:41:50

G-Helper终极硬件优化完整指南:快速提升华硕设备性能

G-Helper终极硬件优化完整指南:快速提升华硕设备性能 【免费下载链接】g-helper Lightweight Armoury Crate alternative for Asus laptops. Control tool for ROG Zephyrus G14, G15, G16, M16, Flow X13, Flow X16, TUF, Strix, Scar and other models 项目地址…

作者头像 李华
网站建设 2026/4/18 10:18:30

OpenCore Configurator 终极配置指南:轻松打造完美黑苹果系统

OpenCore Configurator 终极配置指南:轻松打造完美黑苹果系统 【免费下载链接】OpenCore-Configurator A configurator for the OpenCore Bootloader 项目地址: https://gitcode.com/gh_mirrors/op/OpenCore-Configurator OpenCore Configurator 是一款专为黑…

作者头像 李华