AI 辅助开发实战:基于 Python 的茶叶识别毕设系统设计与优化
配图:一张典型“茶叶叶片”数据集示例图
1. 毕设三大痛点:数据少、调参盲、部署难
标注数据少
实验室只给了 5 类茶叶、每类 200 张手机随手拍,分辨率 800×600,背景杂乱,光线差异大。传统手工标注费时,且类别极度不平衡——“龙井”样本数是“普洱”的 3 倍。调参盲目
多数同学把model.fit()一跑就完事,准确率 70% 卡住。学习率、增广策略、 backbone 选谁、是否冻结,全靠“玄学”网格搜索,GPU 机时又有限。部署困难
训练完拿到.pth文件,却卡在“怎么给前端同学调用”。Flask 一并发就崩,Docker 镜像 4 GB,树莓派直接 OOM;老师还要求 HTTPS + 输入校验,瞬间头大。
2. AI 辅助工具实测:谁更懂“茶叶”?
我同时开了三款插件做对比,统一提示语:“write PyTorch Dataset for tea leaf classification with Albumentations”。结果如下:
| 工具 | 代码完整度 | 是否自动 import | 中文注释 | 额外建议 |
|---|---|---|---|---|
| GitHub Copilot | 90% | 自动给出get_transforms()函数 | ||
| Amazon CodeWhisperer | 85% | 额外推荐timm模型 | ||
| Tabnine | 75% | 只给类骨架 |
结论:Copilot 对“数据增强 + 自定义 Dataset”场景最友好;CodeWhisperer 在“模型选型”阶段提示更激进,可直接呼出timm.create_model('mobilenetv2_100', pretrained=True),节省查文档时间。
3. 核心实现:从增广到 Flask API
3.1 数据增强:Albumentations 三板斧
import albumentations as A from albumentations.pytorch import ToTensorV2 train_tf = A.Compose([ A.RandomResizedCrop(224, 224, scale=(0.7, 1.0)), A.HorizontalFlip(p=0.5), A.RandomRotate90(), A.ColorJitter(0.2, 0.2, 0.2, 0.1, p=0.5), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2() ])AI 插件直接补全了ColorJitter参数范围,避免我反复查官方文档。
3.2 轻量级 backbone 对比
| 模型 | 参数量 | 量化后大小 | CPU 延迟 (224×224) | 准确率@90epoch |
|---|---|---|---|---|
| MobileNetV2 | 2.2 M | 5.1 MB | 28 ms | 92.4 % |
| ResNet18 | 11.7 M | 22 MB | 65 ms | 93.1 % |
结论:MobileNetV2 牺牲 0.7 个百分点,换来 4 倍体积压缩、2.3 倍速度提升,毕设树莓派 demo 场景更香。
3.3 Flask API 封装(带输入校验)
from flask import Flask, request, jsonify from werkzeug.utils import secure_filename import torch, torchvision.transforms as T from PIL import Image app = Flask(__name__) model = torch.load('mobilenetv2_tea.pth', map_location='cpu') model.eval() def allowed_file(filename): return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg'} @app.route('/predict', methods=['POST']) def predict(): if 'file' not in request.files: return jsonify(error='No file part'), 400 file = request.files['file'] if not allowed_file(file.filename): return jsonify(error='Invalid extension'), 400 img = Image.open(file).convert('RGB') tf = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) x = tf(img).unsqueeze(0) with torch.no_grad(): out = model(x) pred = int(out.argmax(1)) return jsonify(label=pred)Copilot 在写校验allowed_file函数时,自动补全了白名单集合,避免我手写字符串拼写出错。
4. Clean Code 训练脚本(关键片段)
# train.py import torch, os, argparse, logging from torch.utils.data import DataLoader from torch.optim.lr_scheduler import CosineAnnealingLR from model import get_model from dataset import TeaDataset logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--data_root', required=True) parser.add_argument('--epochs', type=int, default=90) parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--num_classes', type=int, default=5) parser.add_argument('--save_path', default='best.pth') return parser.parse_args() def train_one_epoch(model, dl, criterion, opt, device): model.train() running_loss, correct, total = 0.0, 0, 0 for x, y in dl: x, y = x.to(device), y.to(device) opt.zero_grad() out = model(x) loss = criterion(out, y) loss.backward() opt.step() running_loss += loss.item() * x.size(0) total += y.size(0) correct += (out.argmax(1) == y).sum().item() return running_loss/total, correct/total def main(): args = get_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') train_ds = TeaDataset(args.data_root, split='train', transforms=train_tf) val_ds = TeaDataset(args.data_root, split='val', transforms=val_tf) train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4) val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4) model = get_model('mobilenetv2', args.num_classes, pretrained=True).to(device) criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4) scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs) best_acc = 0.0 for epoch in range(1, args.epochs+1): tr_loss, tr_acc = train_one_epoch(model, train_dl, criterion, optimizer, device) val_loss, val_acc = evaluate(model, val_dl, criterion, device) scheduler.step() logging.info(f'Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | val loss {val_loss:.4f} acc {val_acc:.3f}') if val_acc > best_acc: best_acc = val_acc torch.save(model.state_dict(), args.save_path) logging.info(f'Best val acc {best_acc:.3f}') if __name__ == '__main__': main()亮点
- 日志用
logging而非print,方便重定向到文件 argparse集中管理超参,毕设答辩时老师一眼看清label_smoothing与CosineAnnealingLR组合,CodeWhisperer 提示后实验发现能再提 1.2 % 准确率
5. 性能 & 安全:大小、延迟、内存
模型大小
FP32:5.1 MB;INT8 量化(Post-training)后 1.3 MB,准确率下降 0.3 %,可接受。推理延迟
- i7-12700H CPU 单线程:28 ms
- 树莓派 4B CPU:110 ms
- 加 ONNXRuntime + CPUProvider:降至 85 ms
内存占用
Flask 单进程常驻 180 MB;并发 4 线程峰值 350 MB,树莓派 4G 版本无压力。HTTPS & 输入校验
使用gunicorn+nginx反向代理,Let's Encrypt 自动续签。nginx 层加client_max_body_size 5M;防止大图片打爆内存;后端PIL打开前用Image.verify()防炸弹图。
6. 生产环境避坑指南
类别不平衡
采用WeightedRandomSampler给少数类加权重,或离线复制+轻微 ColorJitter 扩增,避免盲目class_weight导致过拟合。模型冷启动
量化后第一次推理仍要建缓存,把model(torch.randn(1,3,224,224))放在 Flask 启动脚本里跑一遍,用户请求不再卡顿。过拟合
- Albumentations 增广必须在线做,离线扩增 10 倍 JPG 会炸硬盘
- 训练阶段加
CutMix或RandAugment,CodeWhisperer 直接给出timm.data.mixup,一行代码即可 - 冻结 backbone 前 10 % epoch,让分类头先热身,之后再解冻一起微调,提升 0.8 % 准确率
日志与监控
记录每条预测耗时、返回置信度,方便后续用 Grafana 看漂移。异常样本(置信度<0.6)自动写入review/目录,定期人工复核,实现半自动迭代。
配图:树莓派实拍部署图
7. 下一步:换 backbone & ONNX 压测
把MobileNetV3-Small或EfficientNet-Lite0拖进timm.create_model()做替换,只需改一行;再用torch.onnx.export转模型,用onnxruntime-gpu压测 1000 张图片,看能否把延迟压到 50 ms 以下。欢迎你也来试试,然后把结果告诉我——AI 辅助开发,让毕设不再熬夜。