Core Principles of the BP Neural Network
A BP (backpropagation) neural network is a multilayer feedforward network trained with the error-backpropagation algorithm, consisting of an input layer, one or more hidden layers, and an output layer. The core idea is to compute the error between the predicted and true values, propagate that error backward through the network to adjust the weights, and thereby reduce the loss function step by step. The key steps are:

1. Forward propagation: the input is passed layer by layer through the network to produce a prediction.
2. Loss computation: the error between the prediction and the true label is measured by a loss function.
3. Backward propagation: the chain rule is applied to propagate the error from the output layer back through each hidden layer, yielding the gradient of the loss with respect to every weight.
4. Weight update: each weight and bias is moved along the negative gradient direction, scaled by the learning rate.
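As a worked sketch of these steps for the single-hidden-layer sigmoid network implemented below (a mean-squared-error loss is assumed here, which matches the code's output delta of the prediction error times the sigmoid derivative):

$$
\begin{aligned}
&\text{Forward:} && z_1 = XW_1 + b_1,\quad a_1 = \sigma(z_1),\quad z_2 = a_1 W_2 + b_2,\quad a_2 = \sigma(z_2)\\
&\text{Deltas:} && \delta_2 = (a_2 - y)\odot a_2(1 - a_2),\qquad \delta_1 = \left(\delta_2 W_2^{\top}\right)\odot a_1(1 - a_1)\\
&\text{Updates:} && W_2 \leftarrow W_2 - \tfrac{\eta}{m}\, a_1^{\top}\delta_2,\qquad W_1 \leftarrow W_1 - \tfrac{\eta}{m}\, X^{\top}\delta_1
\end{aligned}
$$

where $\sigma$ is the sigmoid function with $\sigma'(z) = \sigma(z)(1 - \sigma(z))$, $\odot$ is elementwise multiplication, $\eta$ is the learning rate, and $m$ is the batch size; the biases are updated analogously using column sums of the deltas.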
Complete Python Implementation of a BP Neural Network
The following code uses the Iris classification task as an example to walk through a standard BP neural network implementation:
```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder


class BPNeuralNetwork:
    def __init__(self, input_size, hidden_size, output_size):
        self.W1 = np.random.randn(input_size, hidden_size) * 0.01   # input-to-hidden weights
        self.b1 = np.zeros((1, hidden_size))                        # hidden-layer bias
        self.W2 = np.random.randn(hidden_size, output_size) * 0.01  # hidden-to-output weights
        self.b2 = np.zeros((1, output_size))                        # output-layer bias

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

    def sigmoid_derivative(self, a):
        # Expects the activation value a = sigmoid(z), so sigma'(z) = a * (1 - a)
        return a * (1 - a)

    def forward(self, X):
        self.z1 = np.dot(X, self.W1) + self.b1
        self.a1 = self.sigmoid(self.z1)   # hidden-layer activations
        self.z2 = np.dot(self.a1, self.W2) + self.b2
        self.a2 = self.sigmoid(self.z2)   # output-layer predictions
        return self.a2

    def backward(self, X, y, lr=0.01):
        m = X.shape[0]
        # Output-layer error and delta
        a2_error = self.a2 - y
        a2_delta = a2_error * self.sigmoid_derivative(self.a2)
        # Hidden-layer error and delta (chain rule through W2)
        a1_error = np.dot(a2_delta, self.W2.T)
        a1_delta = a1_error * self.sigmoid_derivative(self.a1)
        # Gradient-descent updates, averaged over the batch
        self.W2 -= lr * np.dot(self.a1.T, a2_delta) / m
        self.b2 -= lr * np.sum(a2_delta, axis=0, keepdims=True) / m
        self.W1 -= lr * np.dot(X.T, a1_delta) / m
        self.b1 -= lr * np.sum(a1_delta, axis=0, keepdims=True) / m

    def train(self, X, y, epochs=1000):
        for _ in range(epochs):
            self.forward(X)
            self.backward(X, y)


# Data preparation (Iris dataset)
iris = load_iris()
X = iris.data
y = iris.target.reshape(-1, 1)
encoder = OneHotEncoder(sparse_output=False)  # use sparse=False on scikit-learn < 1.2
y_onehot = encoder.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y_onehot, test_size=0.2)

# Train the model
model = BPNeuralNetwork(input_size=4, hidden_size=5, output_size=3)
model.train(X_train, y_train, epochs=5000)

# Evaluate test accuracy
predictions = np.argmax(model.forward(X_test), axis=1)
true_labels = np.argmax(y_test, axis=1)
accuracy = np.mean(predictions == true_labels)
print(f"Test accuracy: {accuracy * 100:.2f}%")
```
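One design choice worth noting: the code above uses sigmoid outputs with a mean-squared-error-style delta. For multi-class problems, a common alternative (not part of the original code; shown only as a sketch) is a softmax output layer with cross-entropy loss, under which the output-layer delta simplifies to exactly `a2 - y`:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# With softmax + cross-entropy, backward() would set the output delta to
# a2_delta = self.a2 - y directly, dropping the sigmoid_derivative factor.
```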
Code Analysis and Optimization Suggestions
Learning rate (lr): if the learning rate is too large, training may oscillate or diverge; if it is too small, convergence is slow. Consider a learning-rate schedule that decays lr over epochs, or an adaptive optimizer such as Adam, as sketched below.
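Since the section recommends Adam, here is a minimal, self-contained sketch of the Adam update rule (Kingma & Ba, 2015). The `Adam` class and the idea of keeping one optimizer instance per parameter array are illustrative assumptions, not part of the original code:

```python
import numpy as np

class Adam:
    # Minimal Adam optimizer for a single parameter array (hypothetical helper).
    def __init__(self, shape, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = np.zeros(shape)  # first-moment (mean of gradients) estimate
        self.v = np.zeros(shape)  # second-moment (mean of squared gradients) estimate
        self.t = 0                # step counter for bias correction

    def step(self, param, grad):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad ** 2
        m_hat = self.m / (1 - self.beta1 ** self.t)  # bias-corrected first moment
        v_hat = self.v / (1 - self.beta2 ** self.t)  # bias-corrected second moment
        return param - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

# Illustrative wiring: inside backward(), compute grad_W2 = np.dot(self.a1.T, a2_delta) / m
# and replace the plain gradient-descent update with: self.W2 = adam_W2.step(self.W2, grad_W2)
```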