本科毕设里的“三座大山”
做深度学习毕设,最怕的不是调不出 SOTA,而是被这三件事反复折磨:
- 环境依赖冲突——今天还能跑,明天
pip install一下就全红; - 训练过程黑盒——loss 曲线靠手截屏,最佳权重靠“感觉”;
- 结果无法复现——同一份代码,同一张显卡,两次跑 mAP 能差 2 个点。
我当年也被折磨得怀疑人生,后来干脆用“AI 辅助开发”的思路给自己搭了条轻量级流水线,把实验、记录、部署三件事一次性封装好。毕设答辩时,老师一句“你这结果能复现吗?”我直接git checkout+bash train.sh现场跑通,分数自然水涨船高。
下面把这套模板拆开讲,200 行代码就能带走,拿走不谢。
技术选型:别一上来就“重武器”
| 维度 | 原生 PyTorch | PyTorch Lightning | 结论 |
|---|---|---|---|
| 样板代码 | 多 | 少 | Lightning 把训练循环抽象掉,毕设够用 |
| 学习曲线 | 平缓 | 有门槛 | 掌握LightningModule和Trainer即可 |
| 日志生态 | 自己接 | 一键开关 | W&B 比 TensorBoard 更适合“上传即分享” |
因此选型如下:
- 框架:PyTorch Lightning 2.1+
- 实验管理:Weights & Biases(wandb)
- 环境隔离:conda +
environment.yml锁定版本 - 硬件:单卡 2080Ti 即可,多卡反而增加调试复杂度
200 行搞定“配置驱动”流水线
目录先拆好,拒绝面条代码:
dataset/ ├─ __init__.py ├─ mnist_csv.py # 示例:把 MNIST 放 CSV 里读 models/ ├─ __init__.py ├─ lenet.py # 经典 LeNet,易改 configs/ └─ mnist_lenet.yaml # 一条命令换数据集/模型 train.py utils/ ├─ io.py └─ system.py1. 配置文件:把“魔法数字”赶到 yaml
# mnist_lenet.yaml data: class_path: dataset.mnist_csv.MNISTCSVDataModule data_dir: ./data batch_size: 64 model: class_path: models.lenet.LitLeNet num_classes: 10 lr: 1.0e-3 trainer: max_epochs: 10 accelerator: gpu devices: 1 seed: 42 wandb_project: dl_grad一条python train.py --config configs/mnist_lenet.yaml就能跑,换数据集只改两行。
2. 数据模块:保证“幂等”下载+缓存
# dataset/mnist_csv.py class MNISTCSVDataModule(LightningDataModule): def __init__(self, data_dir: str, batch_size: int = 64): super().__init__() self.data_dir = Path(data_dir) self.batch_size = batch_size self.prepare_data_per_node = True # 多卡时只下一份 def prepare_data(self): # 幂等:文件已存在直接跳过 if (self.data_dir / "mnist_train.csv").exists(): return download_and_extract(self.data_dir) # 自己封装的下载函数 def setup(self, stage=None): # 这里做 train/val 拆分,保证每次随机种子固定 df = pd.read_csv(self.data_dir / "mnist_train.csv") train_df, val_df = train_test_split( df, test_size=0.1, random_state=42 ) ...幂等性靠“文件存在即跳过”实现,错误处理把下载异常抛给上层,脚本直接 exit code1,防止静默失败。
3. 模型模块:只关心“前向+优化”
# models/lenet.py class LitLeNet(LightningModule): def __init__(self, num_classes: int = 10, lr: float = 1e-3): super().__init__() self.save_hyperparameters() self.model = LeNet(num_classes) self.criterion = nn.CrossEntropyLoss() def forward(self, x): return self.model(x) def training_step(self, batch, _): x, y = batch logits = self(x) loss = self.criterion(logits, y) self.log("train_loss", loss, prog_bar=True) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)Lightning 自动把backward和zero_grad做完,毕设里少写 30% 代码,降低出错概率。
4. 训练入口:解析配置 → 实例化 → 开跑
# train.py import yaml from jsonargparse import ArgumentParser from pytorch_lightning import Trainer from utils.system import set_seed def main(): parser = ArgumentParser() parser.add_argument("--config", required=True) args = parser.parse_args() cfg = yaml.safe_load(open(args.config)) set_seed(cfg["seed"]) # 固定随机种子 data_module = instantiate(cfg["data"]) # 动态 import model = instantiate(cfg["model"]) wandb_logger = WandbLogger(project=cfg["wandb_project"], save_dir="./logs") trainer = Trainer(**cfg["trainer"], logger=wandb_logger) trainer.fit(model, data_module) if __name__ == "__main__": main()instantiate是 PyTorch Lightning 自带的工厂函数,不用自己写一堆if-else。
性能冷启动:GPU 内存与调试成本
- 第一次
wandb login会拉取 30 M 左右依赖,校园网慢的话提前换源; - Lightning 的
Trainer(devices=1)在 2080Ti 上跑 MNIST,batch=64 只占 1.1 G,毕设常见 224×224 输入也远没爆显存; - 若用多卡,
strategy='ddp'会一次性复制模型,冷启动多 400 M,建议先单卡调通再横向扩展; - 关闭
wandb的在线同步可省 5% 训练时间:export WANDB_MODE=offline,答辩前再wandb sync。
生产环境避坑指南
| 坑 | 现象 | 解法 |
|---|---|---|
| 随机种子漏掉 | 两次结果差 2% 以上 | 在set_seed里同时设置torch.cuda.random.manual_seed_all |
| 路径硬编码 | 换电脑找不到/home/aaa/data | 所有路径读自配置文件,支持Path.relative_to |
| 模型序列化兼容性 | Lightning 升级后load_from_checkpoint失败 | 把 Lightning 版本写进environment.yml,并在 README 提示 |
| 日志爆炸 | wandb 上传 10 G checkpoint | ModelCheckpoint(save_top_k=1)只保留最佳,关闭save_last |
| 中文路径 | Windows 下cv2.imread返回 None | 统一用Path.as_posix()再丢给 OpenCV |
一张图看懂流水线
把模板改成“你的毕设”
- 换数据:在
dataset/里新建my_dataset.py,继承LightningDataModule,实现setup()和*_dataloader(); - 换模型:把
models/lenet.py里的LeNet替换成 ResNet、ViT、甚至 Swin; - 调超参:直接在 yaml 里改
lr、max_epochs,wandb 会自动记录每次 ablation; - 多任务:用 Lightning 的
Task和DataModule多实例,共享一个Trainer,一次fit就能出多份指标。
结尾
整套代码我放在校内 GitLab,每年新生git clone后平均 2 小时就能跑出第一条 loss 曲线。毕设不是发论文,别在工程细节上内耗——把环境、记录、复现交给流水线,你只管折腾想法。
下一步?试试把你的模板搬到 Kaggle 免费 GPU 上,或者把数据模块换成自己实验室的病理切片,看能不能 48 小时复现顶会代码。流水线在手,剩下的就是创意和调参的快乐了。