news 2026/4/24 2:23:30

基于 PyTorch 的 U-Net 训练代码。洪水区域分割数据集

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
基于 PyTorch 的 U-Net 训练代码。洪水区域分割数据集

灾害类-洪水区域分割数据集,增强版

该数据集包含洪水结构位置及其各自掩膜的图像(主要用于陆地和水域的二值分割)。
它包含3.4k 张图像,这些图像是通过扩充之前发布的洪水区域分割数据集获得的。

洪水分段数据和模型在洪水时期对于合理的管理和规划至关重要。
所有图片均为 png 格式,尺寸为 (512,512)。

数据集特点:

  • 图像和蒙版与其名称对应映射,即图像和蒙版对具有相同的文件名

3.4K仗图片以及对应蒙版图片

这是一份非常实用的洪水区域语义分割数据集。基于你提供的信息,我为你整理了数据集的详细统计表,并编写了基于PyTorchU-Net训练代码。

数据集概览

该数据集专注于洪水场景下的水域与陆地分割,经过数据增强后规模适中,适合用于训练鲁棒性较强的分割模型。

统计项数值 / 描述
总样本数3,400 张 (图像 + 掩膜对)
图像格式PNG
图像尺寸512 × 512 像素
任务类型二值语义分割 (Binary Segmentation)
类别数2类 (背景/陆地, 洪水/水域)
数据特点经过数据增强,包含不同光照、角度的洪水场景

注意:数据集中图像和掩膜文件名一一对应,这是语义分割数据集的标准格式。

U-Net 训练代码

以下代码包含 U-Net 模型定义、自定义数据集加载器以及训练循环。请确保你的目录结构如下:

flood_dataset/ ├── images/ (存放3400张原图) └── masks/ (存放3400张掩膜图,像素值应为0或255)
1. 模型定义与工具 (model.py)

这里定义了一个标准的 U-Net 架构。

importtorchimporttorch.nnasnnclassDoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def__init__(self,in_channels,out_channels):super().__init__()self.double_conv=nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))defforward(self,x):returnself.double_conv(x)classDown(nn.Module):"""Downscaling with maxpool then double conv"""def__init__(self,in_channels,out_channels):super().__init__()self.maxpool_conv=nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels,out_channels))defforward(self,x):returnself.maxpool_conv(x)classUp(nn.Module):"""Upscaling then double conv"""def__init__(self,in_channels,out_channels,bilinear=True):super().__init__()ifbilinear:self.up=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)else:self.up=nn.ConvTranspose2d(in_channels//2,in_channels//2,kernel_size=2,stride=2)self.conv=DoubleConv(in_channels,out_channels)defforward(self,x1,x2):x1=self.up(x1)diffY=x2.size()[2]-x1.size()[2]diffX=x2.size()[3]-x1.size()[3]x1=nn.functional.pad(x1,[diffX//2,diffX-diffX//2,diffY//2,diffY-diffY//2])x=torch.cat([x2,x1],dim=1)returnself.conv(x)classUNet(nn.Module):def__init__(self,n_channels=3,n_classes=1,bilinear=True):super(UNet,self).__init__()self.n_channels=n_channels self.n_classes=n_classes self.bilinear=bilinear self.inc=DoubleConv(n_channels,64)self.down1=Down(64,128)self.down2=Down(128,256)self.down3=Down(256,512)self.down4=Down(512,1024)self.up1=Up(1024,512,bilinear)self.up2=Up(512,256,bilinear)self.up3=Up(256,128,bilinear)self.up4=Up(128,64,bilinear)self.outc=nn.Conv2d(64,n_classes,kernel_size=1)defforward(self,x):x1=self.inc(x)x2=self.down1(x1)x3=self.down2(x2)x4=self.down3(x3)x5=self.down4(x4)x=self.up1(x5,x4)x=self.up2(x,x3)x=self.up3(x,x2)x=self.up4(x,x1)logits=self.outc(x)returnlogits
2. 数据集加载与训练 (train.py)

这段代码处理数据增强(随机翻转、旋转)并执行训练循环。

importosimporttorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataset,DataLoaderfromtorchvisionimporttransformsfromPILimportImageimportnumpyasnpfrommodelimportUNet# 导入上面定义的模型# ----------------------------# 1. 自定义数据集类# ----------------------------classFloodDataset(Dataset):def__init__(self,image_dir,mask_dir,transform=None):self.image_dir=image_dir self.mask_dir=mask_dir self.transform=transform self.images=os.listdir(image_dir)def__len__(self):returnlen(self.images)def__getitem__(self,index):img_path=os.path.join(self.image_dir,self.images[index])# 假设掩膜文件名与图片一致,只是扩展名可能不同(这里假设都是png)mask_path=os.path.join(self.mask_dir,self.images[index])image=Image.open(img_path).convert("RGB")mask=Image.open(mask_path).convert("L")# 读取为灰度图ifself.transform:image=self.transform(image)# 对Mask只做ToTensor,不做归一化,因为它是类别标签mask=transforms.ToTensor()(mask)# 将mask二值化 (0 or 1),假设原图mask是0和255mask[mask>=0.5]=1.0mask[mask<0.5]=0.0returnimage,mask# ----------------------------# 2. 训练配置与主循环# ----------------------------deftrain():# 参数配置device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')epochs=50batch_size=8learning_rate=1e-4image_dir='./flood_dataset/images'mask_dir='./flood_dataset/masks'# 数据增强与预处理transform=transforms.Compose([transforms.Resize((512,512)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(10),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]),])# 加载数据集dataset=FloodDataset(image_dir,mask_dir,transform=transform)# 简单的 8:2 划分train_size=int(0.8*len(dataset))val_size=len(dataset)-train_size train_dataset,val_dataset=torch.utils.data.random_split(dataset,[train_size,val_size])train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)val_loader=DataLoader(val_dataset,batch_size=batch_size,shuffle=False)# 初始化模型model=UNet(n_channels=3,n_classes=1).to(device)criterion=nn.BCEWithLogitsLoss()# 二值分割常用损失函数optimizer=optim.Adam(model.parameters(),lr=learning_rate)print(f"开始训练,设备:{device}")forepochinrange(epochs):model.train()epoch_loss=0forimages,masksintrain_loader:images=images.to(device)masks=masks.to(device)optimizer.zero_grad()outputs=model(images)loss=criterion(outputs,masks)loss.backward()optimizer.step()epoch_loss+=loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss:{epoch_loss/len(train_loader):.4f}")# 每个epoch保存一次模型torch.save(model.state_dict(),f"unet_flood_epoch_{epoch+1}.pth")if__name__=="__main__":train()

训练建议

  1. 损失函数:对于洪水分割这种前景(水)和背景(陆地)可能不平衡的任务,BCEWithLogitsLoss是基础。如果效果不佳,建议尝试Dice Loss或两者的加权组合。
  2. 数据增强:代码中加入了随机水平翻转和旋转。由于洪水场景具有方向不确定性,这些增强非常有效。
  3. 后处理:在预测阶段,模型输出的是概率图(0-1之间),通常取阈值 0.5 进行二值化:prediction = torch.sigmoid(output) > 0.5
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/24 2:23:20

3.1 Python 条件语句(if/elif/else)教程

Python基础学习教程:条件语句(if/elif/else) 在Python编程中,条件语句用于根据不同的条件执行不同的代码块。这类似于日常生活中的决策:如果下雨,我就带伞;否则,我就不带。Python使用if、elif(else if的缩写)和else关键字来实现这种逻辑。本教程将详细解释这些语句的…

作者头像 李华
网站建设 2026/4/24 2:22:17

USB 3.0 PHY调优实战:从眼图分析到链路稳定性优化(附实测数据)

USB 3.0 PHY调优实战&#xff1a;从眼图分析到链路稳定性优化 作为一名长期奋战在硬件设计一线的工程师&#xff0c;我深知USB 3.0 PHY层调优对系统稳定性的关键影响。每当遇到数据传输不稳定、设备频繁断开或速率不达标的问题时&#xff0c;PHY层的信号完整性往往是罪魁祸首。…

作者头像 李华
网站建设 2026/4/24 2:21:17

基于OpenCV的Java人脸识别系统开发实战

1. 项目概述&#xff1a;基于OpenCV的Java人脸识别系统人脸识别技术已经从实验室走向了日常生活&#xff0c;从手机解锁到门禁系统无处不在。而OpenCV作为计算机视觉领域的瑞士军刀&#xff0c;配合Java的跨平台特性&#xff0c;可以快速构建一套实用的人脸识别系统。我在过去三…

作者头像 李华
网站建设 2026/4/24 2:17:21

3大实施策略解决ESP32固件烧录与系统恢复问题

3大实施策略解决ESP32固件烧录与系统恢复问题 【免费下载链接】arduino-esp32 Arduino core for the ESP32 项目地址: https://gitcode.com/GitHub_Trending/ar/arduino-esp32 ESP32作为一款广泛应用于物联网和嵌入式开发的微控制器&#xff0c;其Arduino核心支持为开发…

作者头像 李华