news 2026/4/17 22:39:04

day37打卡

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
day37打卡

@浙大疏锦行

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np


# ==========================================
# 1. Data preprocessing and loading
# ==========================================
def preprocess_data(file_path):
    """Load a credit-default CSV and return model-ready features and labels.

    Steps: drop the 'Id' column, map the 99999999.0 sentinel in
    'Current Loan Amount' to NaN, impute missing values (column mean for
    numeric, most frequent value for categorical), one-hot encode the
    categorical columns, and z-score standardize every feature.

    Args:
        file_path (str): Path to a CSV file containing a
            'Credit Default' target column.

    Returns:
        tuple[pd.DataFrame, pd.Series]: (X, y) where X is the scaled
        feature matrix and y is the raw target column.
    """
    df = pd.read_csv(file_path)

    # The row identifier carries no predictive signal.
    if 'Id' in df.columns:
        df = df.drop(columns=['Id'])

    # 99999999.0 is this dataset's sentinel for a missing loan amount.
    if 'Current Loan Amount' in df.columns:
        df['Current Loan Amount'] = df['Current Loan Amount'].replace(99999999.0, np.nan)

    target_col = 'Credit Default'
    X = df.drop(columns=[target_col])
    y = df[target_col]

    numeric_cols = X.select_dtypes(include=['float64', 'int64']).columns
    categorical_cols = X.select_dtypes(include=['object']).columns

    # Imputation done with pandas (same numerics as sklearn's SimpleImputer
    # mean / most_frequent strategies) — this also fixes a crash: fitting an
    # imputer on an empty column selection raises ValueError, so guard both.
    if len(numeric_cols) > 0:
        X[numeric_cols] = X[numeric_cols].fillna(X[numeric_cols].mean())
    for col in categorical_cols:
        if X[col].isna().any():
            modes = X[col].mode()
            if not modes.empty:
                X[col] = X[col].fillna(modes.iloc[0])

    # One-hot encode; drop_first avoids the dummy-variable trap.
    X = pd.get_dummies(X, columns=categorical_cols, drop_first=True)

    # Z-score standardization with population std (ddof=0), matching
    # sklearn's StandardScaler; constant columns get scale 1.0 to avoid
    # division by zero (StandardScaler does the same).
    # NOTE(review): statistics are computed on the FULL dataset before the
    # train/validation split, leaking validation statistics into training.
    # Acceptable for a demo; fit on the training split only in production.
    mu = X.mean()
    sigma = X.std(ddof=0).replace(0.0, 1.0)
    X = (X - mu) / sigma

    return X, y


class CreditDataset(Dataset):
    """Wraps preprocessed (X, y) frames as float32 tensors for a DataLoader."""

    def __init__(self, X, y):
        self.X = torch.tensor(X.values, dtype=torch.float32)
        # unsqueeze(1) -> shape (N, 1), matching the model's single output unit.
        self.y = torch.tensor(y.values, dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
# ==========================================
# 2. Model definition
# ==========================================
class CreditModel(nn.Module):
    """A small fully connected binary classifier: input -> 64 -> 32 -> 1."""

    def __init__(self, input_dim):
        super(CreditModel, self).__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.layer1 = nn.Linear(input_dim, 64)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(64, 32)
        self.output = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu(self.layer1(x))
        hidden = self.relu(self.layer2(hidden))
        # Sigmoid squashes the logit into a (0, 1) default probability.
        return self.sigmoid(self.output(hidden))


# ==========================================
# 3. Early stopping
# ==========================================
class EarlyStopping:
    """Flags `early_stop` once validation loss stalls for too long.

    Args:
        patience (int): epochs without improvement tolerated before stopping.
        min_delta (float): minimum decrease that counts as an improvement.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        # First observation just seeds the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return
        # Improvement means dropping by at least min_delta below the best.
        if val_loss <= self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


# ==========================================
# 4. Training helpers
# ==========================================
def train_epoch(model, dataloader, criterion, optimizer):
    """Run one optimization pass over the loader; return the mean batch loss."""
    model.train()
    total = 0.0
    for features, labels in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(dataloader)


def validate_epoch(model, dataloader, criterion):
    """Compute the mean batch loss over the loader without gradient tracking."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for features, labels in dataloader:
            total += criterion(model(features), labels).item()
    return total / len(dataloader)
# ==========================================
# 5. Script entry point
# ==========================================
if __name__ == "__main__":
    # --- A. Data preparation ---
    file_path = 'data.csv'  # expected in the working directory
    X, y = preprocess_data(file_path)

    # 80/20 train/validation split.
    dataset = CreditDataset(X, y)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    input_dim = X.shape[1]

    # --- B. Phase 1: initial training, then checkpoint the weights ---
    print(">>> 开始初始训练 (Phase 1)...")
    model = CreditModel(input_dim)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        train_loss = train_epoch(model, train_loader, criterion, optimizer)
        print(f"Epoch {epoch+1}/10 - Loss: {train_loss:.4f}")

    print(">>> 保存模型权重到 'credit_model.pth'...")
    torch.save(model.state_dict(), 'credit_model.pth')

    # --- C. Phase 2: reload the checkpoint and resume with early stopping ---
    print("\n>>> 加载权重并继续训练 50 轮 (Phase 2)...")
    # Fresh model instance, initialized from the saved weights.
    new_model = CreditModel(input_dim)
    new_model.load_state_dict(torch.load('credit_model.pth'))

    # New optimizer plus an early-stopping monitor on validation loss.
    new_optimizer = optim.Adam(new_model.parameters(), lr=0.001)
    early_stopping = EarlyStopping(patience=5, min_delta=0.001)

    additional_epochs = 50
    for epoch in range(additional_epochs):
        train_loss = train_epoch(new_model, train_loader, criterion, new_optimizer)
        val_loss = validate_epoch(new_model, val_loader, criterion)
        print(f"Resumed Epoch {epoch+1}/{additional_epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")

        # Stop as soon as validation loss has stalled for `patience` epochs.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("!!! 验证集损失不再下降,触发早停 (Early Stopping) !!!")
            break

    print("训练结束。")
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/18 3:26:55

采购计划没抓手?两小时,我搭了一套以销定采的供应链系统

很多公司在做供应链的时候,总觉得自己不是没有数据,就是没有方法。最后变成一句话:计划没抓手,采购全靠经验,结果一拍脑袋就要货,一拍屁股又让供应商等等等等。那天我去看一家工厂的流程,销售预…

作者头像 李华
网站建设 2026/4/18 3:26:12

构建个人专属知识库:访答知识库深度解析与实战指南

构建个人专属知识库:访答知识库深度解析与实战指南 在信息爆炸的时代,高效管理个人知识已成为提升工作效率的关键。本地私有知识库因其数据安全、离线可用等优势,正受到越来越多人的青睐。在众多选择中,知识库以其独特的定位和功能…

作者头像 李华
网站建设 2026/4/17 20:00:16

8个课堂汇报神器,本科生AI工具推荐与使用攻略

8个课堂汇报神器,本科生AI工具推荐与使用攻略 论文写作的“三座大山”:时间、重复率与效率的困境 对于本科生来说,课堂汇报和论文写作从来都不是轻松的任务。从选题到文献综述,从框架搭建到内容撰写,每一个环节都充满了…

作者头像 李华
网站建设 2026/4/17 1:59:48

“从可控性到自主反思”这个短语,似乎描述了一种从外部控制(或自我控制)向内在自主反思的转变过程

“从可控性到自主反思”这个短语,似乎描述了一种从外部控制(或自我控制)向内在自主反思的转变过程。这在心理学、人工智能(AI)和教育等领域都有深刻的体现,代表了个体或系统从被动受控、依赖外部约束…

作者头像 李华
网站建设 2026/4/16 19:12:10

SQL中表删除与表修改

表删除drop table [if exists] students;表修改ALTER语句使用 ALTER TABLE 语句追加, 修改, 或删除列的语法add增加字段:alter table students add [column] dateT date;设置默认值:alter table students add dateT date DEFAULT "2025-12-12"…

作者头像 李华