RBM 分类:手写数字识别#
本教程演示如何使用受限玻尔兹曼机(RBM)在手写数字数据集上进行特征学习与分类。适合初学者理解 RBM 在图像特征提取与分类中的应用流程。
教程目标#
完成本教程后,您将学会:
使用 RBM 从图像数据中提取特征
将 RBM 作为特征提取器与传统分类器结合
可视化 RBM 学到的权重和生成的样本
评估分类模型的性能
运行环境#
示例位置: example/rbm_digits/rbm_digits.ipynb
依赖项:
pip install scikit-learn matplotlib scipy
1. 数据准备#
此部分将展示如何载入经典的数字手写体数据集,并通过简单的数据增强技术(如向四个方向平移图像)来扩展该数据集,以提高模型的泛化能力。
首先,我们使用`load_digits`从`sklearn`库中加载数据,获取8x8像素的手写数字图像及其标签。 接着,对每个原始图像向上下左右四个方向平移,创建更多的训练样本。 最后,将所有图像数据展平为二维数组,并划分成训练集和测试集,同时进行归一化处理,确保各特征值位于[0, 1]区间内,为后续模型训练做好准备。
def load_data(self, plot_img=False):
"载入图片数据"
digits = load_digits()
images = digits.images # 8x8 的图像矩阵
labels = digits.target # 对应的标签
# 获取图像数据和标签
# 扩展数据集
expanded_images = []
expanded_labels = []
for image, label in zip(images, labels):
# 原始图像
expanded_images.append(image)
expanded_labels.append(label)
# 向四个方向平移
for direction in ["up", "down", "left", "right"]:
translated_image = self.translate_image(image, direction)
expanded_images.append(translated_image)
expanded_labels.append(label)
# 将列表转换为 NumPy 数组
expanded_images = np.array(expanded_images)
expanded_labels = np.array(expanded_labels)
# 可视化图像数据和标签
if plot_img:
plt.figure(figsize=(16, 9))
for index in range(5):
plt.subplot(1, 5, index + 1)
plt.imshow(expanded_images[index], origin="lower", cmap="gray")
plt.title("Training: %i\n" % expanded_labels[index], fontsize=18)
# 将图像数据展平为二维数组 (n_samples, 64)
n_samples = expanded_images.shape[0]
data = expanded_images.reshape((n_samples, -1))
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
data, expanded_labels, test_size=0.2, random_state=42
)
# 使用sklearn的MinMaxScaler进行归一化
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
return X_train, X_test, y_train, y_test
2. 模型训练#
此部分展示了如何训练一个受限玻尔兹曼机(RBM)模型。
首先,基于输入数据的特征维度和预设的隐层单元数初始化RBM模型,并将模型放置在指定设备上(如CPU或GPU)。 接着,定义了随机梯度下降(SGD)优化器用于更新模型参数。整个训练过程采用小批量梯度下降法逐步调整模型参数以最小化目标函数, 该函数包括了负对数似然估计以及权重和偏置的衰减项。
在训练过程中,如果启用了`verbose`模式,则会定期输出当前迭代的目标函数值及模型参数统计信息, 并可视化生成的样本图像和模型参数,便于实时监控模型的学习进度和效果。
def fit(self, X, y=None): # 修改接口以符合scikit-learn约定
"""
训练RBM模型
Args:
X: 训练数据,形状为 (n_samples, n_features)
y: 忽略,为兼容scikit-learn接口
"""
# 初始化受限玻尔兹曼机(RBM)模型
rbm = RestrictedBoltzmannMachine(
X.shape[1], # 可见层单元数(特征维度)
self.n_components, # 隐层单元数
)
rbm.to(self.device) # 将模型移动到指定设备(CPU/GPU)
self.rbm = rbm
# 初始化优化器
opt_rbm = SGD(rbm.parameters(), lr=self.learning_rate)
n_samples = X.shape[0] # 样本数量
n_batches = int(np.ceil(float(n_samples) / self.batch_size)) # 批次数量
# 生成每个batch的切片索引
batch_slices = list(
gen_even_slices(n_batches * self.batch_size, n_batches, n_samples=n_samples)
)
X_torch = torch.FloatTensor(X).to(self.device) # 转为torch张量并移动到设备
idx = 0
# 训练循环
for iteration in range(1, self.n_iter + 1):
for step, batch_slice in enumerate(batch_slices):
idx += 1
x = X_torch[batch_slice] # 获取当前batch数据
x = rbm.get_hidden(x) # 正相(计算隐层激活)
s = rbm.sample(self.sampler) # 负相(采样重构数据)
# s = rbm.get_visible(x[:, rbm.num_visible :]) # 使用隐藏层重构可见层
opt_rbm.zero_grad() # 梯度清零
# 计算目标函数(等价于负对数似然),并加权衰减项
w_weight_decay = 0.02 * torch.sum(rbm.quadratic_coef**2) # 权重衰减
b_weight_decay = 0.05 * torch.sum(rbm.linear_bias**2) # 偏置衰减
objective = rbm.objective(x, s) + w_weight_decay + b_weight_decay
# 反向传播并更新参数
objective.backward()
opt_rbm.step()
# 如果verbose,定期评估模型性能和可视化参数
if self.verbose:
print(f"Iteration {idx}, Objective: {objective.item():.6f}")
if (idx - 1) % 20 == 0:
# 打印权重和偏置的均值与最大值
print(
f"jmean {torch.abs(rbm.quadratic_coef).mean()}"
f" jmax {torch.abs(rbm.quadratic_coef).max()}"
)
print(
f"hmean {torch.abs(rbm.linear_bias).mean()}"
f" hmax {torch.abs(rbm.linear_bias).max()}"
)
if self.plot_img:
display_samples = (
rbm.sample(self.sampler)
.cpu()
.numpy()[:20, : rbm.num_visible]
)
# 生成样本
plt.figure(figsize=(16, 2))
plt.imshow(self.gen_digits_image(display_samples, 8))
plt.title(f"Generated samples at iteration {iteration}")
plt.show()
_, axes = plt.subplots(1, 2)
axes[0].imshow(rbm.quadratic_coef.detach().cpu().numpy())
axes[1].imshow(
rbm.quadratic_coef.grad.detach().cpu().numpy()
)
plt.tight_layout()
plt.show()
return self
3. 特征提取与分类#
3.1 训练分类器#
本节展示了如何构建并训练一个由受限玻尔兹曼机(RBM)和逻辑回归组成的两阶段分类流水线,并将其与仅使用原始像素特征的逻辑回归模型进行性能对比。
首先,通过 RBMRunner 加载并预处理手写数字数据集,随后将 RBM 作为特征提取器嵌入到 Pipeline 中,再接上逻辑回归分类器。
训练过程中分别记录两种方法的耗时,并在测试集上评估其准确率与分类报告。
该实验旨在验证 RBM 学习到的高层特征是否能提升下游分类器的性能,同时也为理解深度学习中无监督预训练的作用提供直观示例。
def train_classifier(n_iter=2, use_cim=False):
logistic = LogisticRegression(random_state=42)
# 初始化RBM
rbm = RBMRunner(
n_components=128,
learning_rate=0.1,
batch_size=32,
n_iter=n_iter,
verbose=True,
plot_img=False,
random_state=seed,
use_cim=use_cim,
)
# 加载数据
X_train, X_test, y_train, y_test = rbm.load_data(plot_img=True)
classifier = Pipeline(steps=[("rbm", rbm), ("logistic", logistic)])
########## 训练模型 ##########
logistic.C = 500.0
logistic.max_iter = 1000
# 训练 RBM-Logistic Pipeline
start_time = time.time()
classifier.fit(X_train, y_train)
training_time = time.time() - start_time
print(f"RBM Pipline training completed in {training_time:.2f} seconds")
# 训练 Logistic regression
logistic_classifier = LogisticRegression(C=500.0, max_iter=1000, random_state=42)
start_time = time.time()
logistic_classifier.fit(X_train, y_train)
training_time = time.time() - start_time
print(f"Logistic regression training completed in {training_time:.2f} seconds")
########## 评估模型 ##########
pip_pred = classifier.predict(X_test)
pip_acc = accuracy_score(y_test, pip_pred)
print(
"\nLogistic regression using RBM features:\n%s\n"
% (classification_report(y_test, pip_pred))
)
print(f"Test Accuracy: {pip_acc:.4f}")
log_pred = logistic_classifier.predict(X_test)
log_acc = accuracy_score(y_test, log_pred)
print(
"\nLogistic regression using raw pixel features:\n%s\n"
% (classification_report(y_test, log_pred))
)
print(f"Test Accuracy: {log_acc:.4f}")
return rbm, y_test, log_pred
3.2 可视化权重#
该函数将RBM模型中每个隐单元对应的权重以8×8图像形式可视化,直观展示其从数据中学到的特征模式。支持将结果保存为高分辨率PDF文件。
def plot_weights(self, save_as="qbm_weights", save_pdf=False):
"""绘制权重"""
weights = self.rbm.quadratic_coef.detach().cpu().numpy()
fig, axes = plt.subplots(
8, 16, gridspec_kw={"wspace": 0.1, "hspace": 0.1}, figsize=(16, 7)
)
fig.suptitle(f"{self.n_components} components extracted by QBM", fontsize=16)
fig.subplots_adjust()
for i, ax in enumerate(axes.flatten()):
if i < weights.shape[1]:
ax.imshow(weights[:, i].reshape(8, 8), cmap=plt.cm.gray)
ax.axis("off")
# 保存结果
if save_pdf:
_ensure_result_dir()
plt.savefig(
f"results/{save_as}.pdf", dpi=300, bbox_inches="tight", format="pdf"
)
plt.show()
3.2 可视化混淆矩阵#
该函数使用热力图可视化模型在测试集上的混淆矩阵,清晰展示各类别之间的预测准确率与混淆情况。 支持添加自定义标题后缀,并可将结果保存为高分辨率PDF文件。
def plot_confusion_matrix(self, y_true, y_pred, title_suffix="", save_pdf=False):
"""绘制混淆矩阵"""
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.title(f"Confusion Matrix ({title_suffix})", fontsize=18)
plt.xlabel("Predicted Label", fontsize=16)
plt.ylabel("True Label", fontsize=16)
# plt.xticks(rotation=45)
# 保存结果
if save_pdf:
_ensure_result_dir()
plt.savefig(
f"results/rbm_confusion_matrix_{title_suffix}.pdf",
dpi=300,
bbox_inches="tight",
format="pdf",
)
plt.tight_layout()
plt.show()