Q-VAE:量子变分自编码器#

本教程演示如何训练和评估量子变分自编码器(Quantum Variational Autoencoder, Q-VAE)模型。Q-VAE 结合了变分自编码器和量子玻尔兹曼机,能够实现更强大的生成和表征学习能力。

目标#

  • 理解 Q-VAE 的架构和工作原理

  • 在 MNIST 数据集上训练 Q-VAE

  • 进行图像重建和生成

  • 使用 Q-VAE 进行表征学习和分类

  • 使用 t-SNE 可视化潜在空间

运行环境#

示例位置: example/qvae_mnist/

  • train_qvae.ipynb: 训练 Q-VAE 模型

  • train_qvae_classifier.ipynb: 表征学习与分类

依赖项:

pip install torchvision==0.22.0 torchmetrics[image]

1、QVAE 原理概括#

QVAE(Quantum Variational Autoencoder)是一种将 量子生成模型 引入 变分自编码器 (VAE) 潜空间的生成模型。其核心思想是:

> 用量子玻尔兹曼机(QBM)替代传统 VAE 中的先验分布,从而构建一个具有量子生成能力的潜变量模型。

模型结构#

QVAE 包括以下关键组件:

  1. 编码器(Encoder) 将输入数据 \mathbf{x} 映射为潜变量的近似后验分布 q_\phi(\mathbf{z}|\mathbf{x}),通常由神经网络参数化。

  2. 先验分布(Prior) 使用 量子玻尔兹曼机 (QBM) 建模潜变量 \mathbf{z} 的先验分布。哈密顿量为:

    \mathcal{H}_\theta = \sum_l \Gamma_l \sigma_l^x + \sum_l h_l \sigma_l^z + \sum_{l<m} W_{lm} \sigma_l^z \sigma_m^z

  3. 解码器(Decoder) 将潜变量 \mathbf{z} (或其连续松弛变量 \boldsymbol{\zeta} )映射回数据空间,并使用解码器重建原始数据:

    p_\theta(\mathbf{x} | \boldsymbol{\zeta}) \sim \text{Bernoulli}(f_\theta(\boldsymbol{\zeta}))

训练目标:Q-ELBO#

QVAE 使用一个 量子下界 (Q-ELBO) 来近似最大化对数似然:

\mathcal{L}_{\text{Q-ELBO}} = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})} [\log p_\theta(\mathbf{x} | \boldsymbol{\zeta})] - \tilde{H}(q_\phi(\mathbf{z}|\mathbf{x}) \| p_\theta(\mathbf{z}))

QBM 采样与训练#

  • 正相(positive phase):从编码器采样 \mathbf{z} \sim q_\phi(\mathbf{z}|\mathbf{x})

  • 负相(negative phase):从 QBM 中采样 \mathbf{z} \sim p_\theta(\mathbf{z}),使用 蒙特卡洛方法量子退火器

把能量作为目标函数,objective 的梯度即为基于正相和负相采样计算的梯度。

2. 模型架构#

定义了用于自编码器架构的 EncoderDecoder 两个模块,均继承自 nn.Module

两者结构对称:包含一个全连接层、层归一化(LayerNorm)和双曲正切激活函数,并支持通过 L2 权重衰减进行正则化。 编码器将高维输入映射到低维潜在空间,而解码器尝试从潜在表示重构原始输入。 每个模块提供 get_weight_decay 方法,用于在训练损失中显式加入权重正则项,以提升模型泛化能力并防止过拟合。

2.1 编码器#

class Encoder(nn.Module):

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        latent_dim: int,
        weight_decay: float = 0.01,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.weight_decay = weight_decay
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm1(x)
        x = F.tanh(x)
        x = self.fc2(x)
        return x

    def get_weight_decay(self) -> torch.Tensor:
        """计算权重的L2正则化损失

        对权重矩阵施加L2正则化可以提高模型的泛化能力。

        Returns:
            torch.Tensor: L2正则化损失值
        """
        return self.weight_decay * (
            torch.sum(self.fc1.weight**2) + torch.sum(self.fc2.weight**2)
        )

2.2 解码器#

class Decoder(nn.Module):

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        latent_dim: int,
        weight_decay: float = 0.01,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.weight_decay = weight_decay
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)

    def forward(self, z):
        z = self.fc1(z)
        z = self.norm1(z)
        z = F.tanh(z)
        z = self.fc2(z)

        return z

    def get_weight_decay(self) -> torch.Tensor:
        """计算权重的L2正则化损失

        对权重矩阵施加L2正则化可以提高模型的泛化能力。

        Returns:
            torch.Tensor: L2正则化损失值
        """
        return self.weight_decay * (
            torch.sum(self.fc1.weight**2) + torch.sum(self.fc2.weight**2)
        )

2.3 Q-VAE 完整模型#

参考模块手册中的QVAE类。

3. 数据准备#

该函数封装了 MNIST 数据集的加载与预处理流程,返回训练和测试用的 DataLoader

数据通过 ToTensor 转换为张量,并利用自定义的 flatten_tensor 将 28×28 图像展平为 784 维向量,适配全连接网络输入。 训练加载器启用打乱(shuffle),而测试加载器保持顺序以确保评估一致性。

def setup_data_loaders(root, download=True, batch_size=256, use_cuda=False):
   """
   设置MNIST数据集的数据加载器

   Args:
      root (str): 数据存储根目录
      download (bool): 如果数据不存在是否下载,默认为True
      batch_size (int): 每个批次的样本数量,默认为128
      use_cuda (bool): 是否使用GPU,决定是否启用pin_memory优化

   Returns:
      tuple: (train_loader, test_loader) 训练和测试数据加载器
   """
   # 数据预处理
   transform = transforms.Compose([
      transforms.ToTensor(),             # 转换为Tensor
      transforms.Lambda(flatten_tensor)  # 展平:将28x28图像展平成784维向量
      # 等效于:x.reshape(-1) 或 x.flatten()
   ])

   # 加载训练集
   train_set = datasets.MNIST(
      root=root,           # 数据存储路径
      train=True,          # 加载训练集(共60000个样本)
      transform=transform, # 应用定义的数据变换
      download=download    # 如果数据不存在则自动下载
   )

   # 加载测试集
   test_set = datasets.MNIST(
      root=root,           # 数据存储路径
      train=False,         # 加载测试集(共10000个样本)
      transform=transform  # 应用相同的数据变换
   )

   # 数据加载器配置参数
   # 根据是否使用GPU选择不同的优化参数
   # 将num_workers设为0避免多进程问题
   kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {'num_workers': 0}

   # 创建训练数据加载器
   train_loader = DataLoader(
      dataset=train_set,     # 训练数据集
      batch_size=batch_size, # 每个批次的样本数
      shuffle=True,          # 每个epoch打乱数据顺序,防止模型记忆顺序
      **kwargs               # 解包上述配置参数
   )

   # 创建测试数据加载器
   test_loader = DataLoader(
      dataset=test_set,      # 测试数据集
      batch_size=batch_size, # 批次大小(通常与训练集相同)
      shuffle=False,         # 测试集不需要打乱,保证可重复性
      **kwargs               # 解包配置参数
   )

   return train_loader, test_loader

4. 模型训练#

该函数实现了量子变分自编码器(Q-VAE)在 MNIST 数据集上的完整训练流程。 模型结合了经典神经网络编码器/解码器与受限玻尔兹曼机(RBM),通过最小化带权重衰减的负 ELBO 损失进行优化,并引入 KL 散度控制潜在分布与先验的对齐程度。 训练过程中记录各项损失指标并定期保存至文件。

def train_qvae(
    train_loader,  # 用于训练QVAE
    device,
    input_dim=784,  # 图片拉伸后的维度
    hidden_dim=512,  # fc1压缩后的维度
    latent_dim=256,  # 隐变量维度, num_visible + num_hidden
    num_var1=128,  # RBM可见层维度
    num_var2=128,  # RBM藏层维度
    dist_beta=10,  # 重叠分布的beta
    weight_decay=0.01,
    batch_size=256,
    epochs=20,
    lr=1e-3,
    kl_beta=0.000001,
    save_path="./models/",
    sampler_type="sa",
):
    # 创建模型
    model, optimizer = create_model(
        train_loader,
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        weight_decay=weight_decay,
        dist_beta=dist_beta,
        num_var1=num_var1,
        num_var2=num_var2,
        lr=lr,
        device=device,
        sampler_type=sampler_type,
    )

    # 训练循环
    loss_history = []
    elbo_history = []
    kl_history = []
    cost_history = []

    model.train()  # 设置模型为训练模式
    for epoch in tqdm(range(1, epochs + 1), desc="Training QVAE"):  # 遍历每个训练轮次
        # 训练一个epoch
        avg_loss, avg_elbo, avg_kl, avg_cost = _train_epoch(
            model, train_loader, optimizer, kl_beta, device
        )

        # 记录历史指标
        loss_history.append(avg_loss)
        elbo_history.append(avg_elbo)
        kl_history.append(avg_kl)
        cost_history.append(avg_cost)

        save_list_to_txt(os.path.join(save_path, "loss_history.txt"), loss_history)
        save_list_to_txt(os.path.join(save_path, "elbo_history.txt"), elbo_history)
        save_list_to_txt(os.path.join(save_path, "cost_history.txt"), cost_history)
        save_list_to_txt(os.path.join(save_path, "kl_history.txt"), kl_history)

        # # 保存当前轮次的模型参数
        # model_save_path = os.path.join(save_path, f'davepp_epoch{epoch}.pth')
        # torch.save(model.state_dict(), model_save_path)

        # 打印本轮训练结果
        print(
            f"Epoch {epoch}/{epochs}: "
            f"Loss: {avg_loss:.4f}, "
            f"elbo: {avg_elbo:.4f}, "
            f"KL: {avg_kl:.4f}, "
            f"Cost: {avg_cost:.4f}"
        )

    # 保存模型
    model_save_path = os.path.join(save_path, f"qvae_mnist.pth")
    torch.save(model.state_dict(), model_save_path)
    return model

5. 可视化与评估#

本节提供两类关键可视化工具: 一是通过 plot_training_curves 绘制训练/验证损失与准确率曲线,用于监控模型收敛情况; 二是利用 t_SNE 对 QVAE 模型提取的潜在表示进行降维可视化,揭示不同类别在隐空间中的分布结构。 两者均支持自动保存高分辨率图像,并可灵活控制是否实时显示,便于实验分析、结果记录。

5.1 训练过程可视化#

def plot_training_curves(
    train_loss_history,
    val_loss_history,
    train_acc_history,
    val_acc_history,
    save_path=None,
    show=True,
):
    """
    绘制训练和验证的损失及准确率曲线

    Args:
        train_loss_history: 训练损失历史
        val_loss_history: 验证损失历史
        train_acc_history: 训练准确率历史
        val_acc_history: 验证准确率历史
        save_path: 图像保存路径
    """
    plt.figure(figsize=(12, 5))

    # 损失曲线
    plt.subplot(1, 2, 1)
    plt.plot(train_loss_history, label="Training Loss", color="blue", alpha=0.7)
    plt.plot(val_loss_history, label="Validation Loss", color="red", alpha=0.7)
    plt.title("Training and Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    # 准确率曲线
    plt.subplot(1, 2, 2)
    plt.plot(train_acc_history, label="Training Accuracy", color="blue", alpha=0.7)
    plt.plot(val_acc_history, label="Validation Accuracy", color="red", alpha=0.7)
    plt.title("Training and Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    # plt.close()

    # 自动保存
    if save_path is None:
        # 生成默认保存路径
        # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = f"results/mlp_training_curves_{timestamp}.png"

    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"Training curves saved to: {save_path}")
    plt.show()

    if show:
        plt.show()
    else:
        plt.close()  # 不显示时关闭图像,节省内存

5.3 潜在空间可视化#

def t_SNE(
    test_loader,
    qvae_model,
    use_std=True,
    point_size=20,
    alpha=0.6,
    epochs=None,
    save_path=None,
    show=True,
):
    """
    QVAE版本的t-SNE可视化

    Args:
        test_loader: 测试数据加载器
        qvae_model: QVAE模型
        use_std: 是否使用标准差
        point_size: 点大小
        alpha: 透明度
        epochs: 训练轮数
        save_path: 保存路径
        show: 是否显示图像
    """
    features = []
    labels = []

    qvae_model.eval()
    device = next(qvae_model.parameters()).device

    with torch.no_grad():
        for batch_idx, (example_data, example_targets) in enumerate(test_loader):
            example_data = example_data.to(device)

            # QVAE前向传播 - 获取潜变量zeta
            _, _, _, zeta = qvae_model(example_data)

            zeta_np = zeta.cpu().numpy()

            for idx in range(zeta_np.shape[0]):
                features.append(zeta_np[idx])
                labels.append(example_targets[idx].item())

    # 创建DataFrame
    feat_cols = [f"dim_{i}" for i in range(zeta_np.shape[1])]
    df = pd.DataFrame(features, columns=feat_cols)
    df["label"] = labels
    df["label"] = df["label"].apply(lambda i: str(i))

    print(f"Extracted {len(features)} samples with {zeta_np.shape[1]} dimensions")

    # 执行t-SNE
    print("Running t-SNE...")
    tsne = TSNE(n_components=2, verbose=1, perplexity=30, max_iter=300, random_state=42)
    tsne_results = tsne.fit_transform(df[feat_cols].values)

    df_tsne = df.copy()
    df_tsne["x-tsne"] = tsne_results[:, 0]
    df_tsne["y-tsne"] = tsne_results[:, 1]

    # 可视化
    # plt.figure(figsize=(10, 8))
    # scatter = plt.scatter(df_tsne['x-tsne'], df_tsne['y-tsne'],
    #                      c=df_tsne['label'].astype(int),
    #                      cmap='tab10', s=point_size, alpha=alpha)
    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(
        df_tsne["x-tsne"],
        df_tsne["y-tsne"],
        c=df_tsne["label"].astype(int),
        cmap="tab10",
        s=point_size,
        alpha=alpha,
    )

    # 添加颜色条
    cbar = plt.colorbar(scatter, ax=ax, label="Digit")

    # 动态标题和文件名
    training_status = "fully_trained" if epochs and epochs >= 20 else f"epochs_{epochs}"
    title = f"t-SNE Visualization of QVAE Latent Space ({training_status})"
    plt.title(title)
    plt.xlabel("t-SNE dimension 1")
    plt.ylabel("t-SNE dimension 2")
    plt.tight_layout()

    # 自动保存
    if save_path is None:
        # 生成默认保存路径
        # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = f"results/t-SNE_QVAE_{training_status}_{timestamp}.png"

    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"t-SNE plot saved to: {save_path}")
    plt.show()

    if show:
        plt.show()
    else:
        plt.close()  # 不显示时关闭图像,节省内存

    return df_tsne, save_path, training_status

6. 表征学习与分类#

Q-VAE 学到的表征可用于下游分类任务:

该函数 train_mlp_classifier 用于训练一个多层感知机(MLP)分类器,输入特征是通过QVAE模型提取的数据表征。 它首先将数据集划分为训练集和验证集,并初始化一个MLP模型、优化器和损失函数。 在每个训练周期,模型参数根据训练集更新,并在验证集上评估性能。

def train_mlp_classifier(
    features,
    labels,
    device,
    epochs=100,
    lr=1e-3,
    weight_decay=1e-4,
    batch_size=64,
    seed=42,
    smoke_test=False,
    show=True,
    save_path="./models/",
):
    # 数据分割
    X_train, X_val, y_train, y_val = train_test_split(
        features.numpy(), labels.numpy(), test_size=0.4, random_state=seed
    )

    X_train = torch.FloatTensor(X_train)
    y_train = torch.LongTensor(y_train)
    X_val = torch.FloatTensor(X_val)
    y_val = torch.LongTensor(y_val)

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # MLP模型
    mlp = MLP(input_dim=features.shape[1], output_dim=10).to(device)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # 记录训练历史
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []

    # 训练循环
    best_acc = 0
    for epoch in tqdm(range(1, epochs + 1), desc="Training MLP"):
        # 训练阶段
        train_acc, avg_train_loss = _train_mlp_epoch(
            model=mlp,
            data_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            device=device,
        )

        # 验证阶段
        val_acc, avg_val_loss = _eval_mlp_epoch(
            model=mlp, data_loader=val_loader, criterion=criterion, device=device
        )

        # 记录历史
        train_loss_history.append(avg_train_loss)
        val_loss_history.append(avg_val_loss)
        train_acc_history.append(train_acc)
        val_acc_history.append(val_acc)

        # 打印训练和验证指标
        if epoch % 5 == 0:
            print(
                f"Epoch {epoch:3d}: "
                f"Train Loss: {avg_train_loss:.4f}, "
                f"Val Loss: {avg_val_loss:.4f}, "  # 新增验证损失
                f"Train Acc: {train_acc:.2f}%, "
                f"Val Acc: {val_acc:.2f}%"
            )

        if val_acc > best_acc:
            best_acc = val_acc
            model_save_path = os.path.join(save_path, "best_mlp_classifier.pth")
            torch.save(mlp.state_dict(), model_save_path)

    print(f"Best Validation Accuracy: {best_acc:.2f}%")

    # 绘制训练曲线
    curves_save_path = ""
    if not smoke_test:
        curves_save_path = os.path.join(
            save_path, f"mlp_training_curves_epochs_{epochs}.png"
        )
        plot_training_curves(
            train_loss_history=train_loss_history,
            val_loss_history=val_loss_history,
            train_acc_history=train_acc_history,
            val_acc_history=val_acc_history,
            save_path=curves_save_path,
            show=show,
        )
    return mlp, best_acc, curves_save_path

7. 科研应用:QBM-VAE#

Q-VAE 的进阶版本 QBM-VAE 在科研中展示了重要价值:

单细胞转录组学分析

  • 显著提升聚类精度

  • 检测传统方法无法辨识的新型细胞亚型

  • 为靶点发现提供新线索

相关论文Quantum-Boosted High-Fidelity Deep Learning