以下是用纯 NumPy从零实现线性回归(Linear Regression)的完整、逐步讲解版本。
我们会实现两种主流方式:
- 闭式解(Normal Equation / 最小二乘法直接求解)—— 适合中小型数据集,一步求出最优解
- 梯度下降(Gradient Descent)—— 适合大数据集、可扩展到深度学习
两种方式都包含:
- 数据生成(带噪声的真实线性关系)
- 模型训练
- 损失曲线 / 参数收敛观察
- 预测与可视化
0. 数学核心(必须理解)
假设函数(单特征为例):
y ^ = w ⋅ x + b \hat{y} = w \cdot x + by^=w⋅x+b
损失函数(均方误差 MSE):
J ( w , b ) = 1 2 m ∑ i = 1 m ( y ^ ( i ) − y ( i ) ) 2 J(w,b) = \frac{1}{2m} \sum_{i=1}^m (\hat{y}^{(i)} - y^{(i)})^2J(w,b)=2m1i=1∑m(y^(i)−y(i))2
目标:最小化 J
方式1:闭式解(Normal Equation)
w = ( X T X ) − 1 X T y \mathbf{w} = (\mathbf{X}^T \mathbf{X})^{-1} \mathbf{X}^T \mathbf{y}w=(XTX)−1XTy
(w 包含截距 b)
方式2:梯度下降(批量梯度下降)
参数更新规则:
w : = w − α ∂ J ∂ w , b : = b − α ∂ J ∂ b w := w - \alpha \frac{\partial J}{\partial w}, \quad b := b - \alpha \frac{\partial J}{\partial b}w:=w−α∂w∂J,b:=b−α∂b∂J
其中偏导数(关键):
∂ J ∂ w = 1 m ∑ i = 1 m ( y ^ ( i ) − y ( i ) ) ⋅ x ( i ) \frac{\partial J}{\partial w} = \frac{1}{m} \sum_{i=1}^m (\hat{y}^{(i)} - y^{(i)}) \cdot x^{(i)}∂w∂J=m1i=1∑m(y^(i)−y(i))⋅x(i)
∂ J ∂ b = 1 m ∑ i = 1 m ( y ^ ( i ) − y ( i ) ) \frac{\partial J}{\partial b} = \frac{1}{m} \sum_{i=1}^m (\hat{y}^{(i)} - y^{(i)})∂b∂J=m1i=1∑m(y^(i)−y(i))
1. 完整代码实现(NumPy from scratch)
importnumpyasnpimportmatplotlib.pyplotasplt# -------------------------------# 生成带噪声的线性数据# -------------------------------np.random.seed(42)m=200# 样本数量true_w=3.14# 真实斜率true_b=8.5# 真实截距X=np.random.uniform(-5,10,size=(m,1))# 单特征,形状 (m,1)noise=np.random.randn(m,1)*3# 高斯噪声y=true_w*X+true_b+noise# y = 3.14x + 8.5 + noise# 增加一列全1,作为截距项(bias trick)X_b=np.c_[np.ones((m,1)),X]# 形状 (m, 2) → [1, x]# 可视化原始数据plt.figure(figsize=(10,6))plt.scatter(X,y,s=40,alpha=0.7,label="真实数据(带噪声)")plt.xlabel("X")plt.ylabel("y")plt.title("生成的数据(真实关系:y ≈ 3.14x + 8.5)")plt.grid(True,alpha=0.3)plt.legend()plt.show()方式一:闭式解(Normal Equation)—— 一行求解
# 闭式解: w = (X^T X)^(-1) X^T ytheta_best=np.linalg.inv(X_b.T @ X_b)@(X_b.T @ y)w_closed=theta_best[1][0]b_closed=theta_best[0][0]print(f"闭式解得到:w ={w_closed:.4f}, b ={b_closed:.4f}")print(f"真实值对比:w ={true_w}, b ={true_b}")# 画出拟合直线X_plot=np.linspace(-5,10,100).reshape(-1,1)X_plot_b=np.c_[np.ones((100,1)),X_plot]y_pred_closed=X_plot_b @ theta_best plt.scatter(X,y,s=40,alpha=0.6,label="数据")plt.plot(X_plot,y_pred_closed,"r-",linewidth=3,label="闭式解拟合")plt.legend()plt.title(f"闭式解:w={w_closed:.3f}, b={b_closed:.3f}")plt.show()方式二:梯度下降(从零实现)
# -------------------------------# 梯度下降实现# -------------------------------defcompute_mse(y_true,y_pred):returnnp.mean((y_true-y_pred)**2)defgradient_descent(X_b,y,learning_rate=0.01,n_iterations=5000,verbose=True):m=len(y)theta=np.random.randn(2,1)# 随机初始化 [b, w]history={"loss":[],"w":[],"b":[]}foriterinrange(n_iterations):# 前向:预测y_pred=X_b @ theta# 误差error=y_pred-y# 梯度(向量化写法,非常重要!)gradients=(2/m)*(X_b.T @ error)# 形状 (2,1)# 更新参数theta=theta-learning_rate*gradients# 记录loss=compute_mse(y,y_pred)history["loss"].append(loss)history["w"].append(theta[1][0])history["b"].append(theta[0][0])ifverboseanditer%500==0:print(f"Iter{iter:5d}loss ={loss:.6f}w={theta[1][0]:.4f}b={theta[0][0]:.4f}")returntheta,history# 运行梯度下降learning_rate=0.02# 注意:学习率对收敛影响很大theta_gd,hist=gradient_descent(X_b,y,learning_rate=learning_rate,n_iterations=3000)w_gd=theta_gd[1][0]b_gd=theta_gd[0][0]print("\n梯度下降最终结果:")print(f"w ={w_gd:.4f}, b ={b_gd:.4f}")print(f"最终 MSE ={hist['loss'][-1]:.6f}")# 画损失下降曲线plt.figure(figsize=(12,4))plt.subplot(1,2,1)plt.plot(hist["loss"],color="darkblue",lw=2)plt.title("损失函数下降曲线 (MSE)")plt.xlabel("迭代次数")plt.ylabel("MSE")plt.grid(True,alpha=0.3)plt.subplot(1,2,2)plt.plot(hist["w"],label="w 参数",alpha=0.8)plt.plot(hist["b"],label="b 参数",alpha=0.8)plt.axhline(true_w,color="red",linestyle="--",label=f"真实 w={true_w}")plt.axhline(true_b,color="orange",linestyle="--",label=f"真实 b={true_b}")plt.title("参数收敛过程")plt.xlabel("迭代次数")plt.legend()plt.grid(True,alpha=0.3)plt.tight_layout()plt.show()# 画最终拟合结果y_pred_gd=X_b @ theta_gd plt.scatter(X,y,s=40,alpha=0.6,label="数据")plt.plot(X_plot,y_pred_closed,"r--",lw=2,label="闭式解")plt.plot(X_plot,X_plot_b @ theta_gd,"g-",lw=3,label="梯度下降")plt.legend()plt.title(f"梯度下降结果:w={w_gd:.3f}, b={b_gd:.3f}")plt.show()总结对比表(面试/理解常用)
| 方法 | 优点 | 缺点 | 适用场景 | 代码复杂度 |
|---|---|---|---|---|
| 闭式解 | 一步得到全局最优解 无超参数 | 计算矩阵逆 O(n³) 大数据崩 | 数据量 < 10,000 | ★☆☆☆☆ |
| 梯度下降 | 可处理海量数据 可在线学习 | 需要调学习率 可能陷入局部最优 | 大数据、深度学习基础 | ★★★☆☆ |
| 随机梯度下降(SGD) | 更快、更省内存 | 噪声大、震荡 | 百万级以上数据 | ★★★★☆ |
进阶练习方向(推荐自己实现)
- 加入L2 正则化(Ridge 回归)
- 实现Mini-batch 梯度下降或SGD
- 加入早停(early stopping)机制
- 用学习率衰减(learning rate decay / scheduler)
- 扩展到多特征(housing price prediction)
- 对比自己实现的版本与
sklearn.linear_model.LinearRegression
有哪一部分想再深入?
例如:
- 加正则化后的代码
- Mini-batch 版本
- 真实房价数据集实战
- 为什么学习率太大会爆炸?
随时告诉我~