VAE In JAX [Personal Notes]

As in the previous post, SAC In JAX, we implement a VAE in JAX. The environment setup is the same; the only extra step is installing a compatible version of torchvision. The implementation includes TensorBoard logging and a final visualization, and the dataset is the classic MNIST.
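For reference, the loss minimized in train_step below is the negative ELBO with a Gaussian encoder and a fixed-variance decoder, so the reconstruction term reduces to a squared error. Note that the quantity the encoder outputs is the log-variance $\log \sigma^2$ (the code samples with $\exp(\tfrac{1}{2}\log\sigma^2) = \sigma$):

$$\mathcal{L} = \sum_i \lVert \hat{x}_i - x_i \rVert^2 - \frac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

The full code is below: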

import jax
import optax
import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from datetime import datetime
from torchvision import datasets
from matplotlib import pyplot as plt
from flax.training.train_state import TrainState
from stable_baselines3.common.logger import configure

def get_data():
    # Load MNIST with torchvision and flatten each 28x28 image to a 784-dim vector.
    train_dataset = datasets.MNIST(root='./data', train=True, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False)
    img_train = train_dataset.data.numpy().reshape(-1, 784)
    # Per-pixel min-max normalization; the test set is scaled with the training statistics.
    train_x_min = np.min(img_train, axis=0)
    train_x_max = np.max(img_train, axis=0)

    train_x = (img_train - train_x_min) / (train_x_max - train_x_min + 1e-7)
    train_y = train_dataset.targets.numpy()
    img_test = test_dataset.data.numpy().reshape(-1, 784)
    test_x = (img_test - train_x_min) / (train_x_max - train_x_min + 1e-7)
    test_y = test_dataset.targets.numpy()
    N = train_x.shape[0]
    M = test_x.shape[0]
    return jnp.asarray(train_x), jnp.asarray(train_y), jnp.asarray(test_x), jnp.asarray(test_y), N, M

class VAE_encoder(nn.Module):
    hidden_dim: int
    latent_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        encode = nn.relu(x)
        # The encoder outputs the mean and the log-variance of q(z | x).
        mu = nn.Dense(self.latent_dim)(encode)
        log_var = nn.Dense(self.latent_dim)(encode)
        return mu, log_var

class VAE_decoder(nn.Module):
    output_dim: int
    hidden_dim: int
    @nn.compact
    def __call__(self, latent):
        # Map a latent code back to pixel space; the sigmoid keeps outputs in [0, 1].
        x = nn.Dense(self.hidden_dim)(latent)
        x = nn.relu(x)
        x = nn.Dense(self.output_dim)(x)
        x = nn.sigmoid(x)
        return x

class VAE:
    def __init__(self, input_dim, encoder_lr, decoder_lr, epochs, batch_size, logger, key):
        self.input_dim = input_dim
        self.encoder_lr, self.decoder_lr = encoder_lr, decoder_lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.logger = logger
        self.key = key
        self.hidden_dim = 400
        self.latent_dim = 48
        self.encoder = VAE_encoder(self.hidden_dim, self.latent_dim)
        self.decoder = VAE_decoder(self.input_dim, self.hidden_dim)
        # Initialize both parameter trees with dummy inputs of the right shapes.
        self.key, encoder_key, decoder_key = jax.random.split(self.key, 3)
        encoder_params = self.encoder.init(encoder_key, jnp.ones((self.batch_size, self.input_dim)))['params']
        decoder_params = self.decoder.init(decoder_key, jnp.ones((self.batch_size, self.latent_dim)))['params']
        encoder_optx = optax.adam(encoder_lr)
        decoder_optx = optax.adam(decoder_lr)
        self.encoder_state = TrainState.create(apply_fn=self.encoder.apply, params=encoder_params, tx=encoder_optx)
        self.decoder_state = TrainState.create(apply_fn=self.decoder.apply, params=decoder_params, tx=decoder_optx)

    @staticmethod
    @jax.jit
    def forward(x, encoder_state, decoder_state, now_key):
        mu, log_var = encoder_state.apply_fn({"params": encoder_state.params}, x)
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        now_key, eps_key = jax.random.split(now_key, 2)
        eps = jax.random.normal(eps_key, shape=mu.shape)
        latent = mu + eps * jnp.exp(log_var * 0.5)
        x_ = decoder_state.apply_fn({"params": decoder_state.params}, latent)
        return x_, now_key


    @staticmethod
    @jax.jit
    def train_step(data, encoder_state, decoder_state, key):
        def loss_fn(encoder_param, decoder_param, encoder_state, decoder_state, now_key):
            mu, log_var = encoder_state.apply_fn({"params": encoder_param}, data)
            now_key, eps_key = jax.random.split(now_key, 2)
            eps = jax.random.normal(eps_key, shape=mu.shape)
            latent = mu + eps * jnp.exp(log_var * 0.5)
            x_ = decoder_state.apply_fn({"params": decoder_param}, latent)
            # Squared-error reconstruction term of the negative ELBO.
            reconstruction_loss = jnp.sum(jnp.square(x_ - data))
            # KL divergence between q(z | x) = N(mu, sigma^2) and the prior N(0, I).
            kl_loss = -0.5 * jnp.sum(1 + log_var - mu ** 2 - jnp.exp(log_var))
            loss = reconstruction_loss + kl_loss
            return loss, (reconstruction_loss, kl_loss, now_key)
        # Differentiate with respect to both parameter trees at once (argnums=(0, 1)).
        (loss, (reconstruction_loss, kl_loss, key)), grads = jax.value_and_grad(loss_fn, has_aux=True, argnums=(0, 1))(
            encoder_state.params, decoder_state.params, encoder_state, decoder_state, key)
        encoder_state = encoder_state.apply_gradients(grads=grads[0])
        decoder_state = decoder_state.apply_gradients(grads=grads[1])
        return encoder_state, decoder_state, reconstruction_loss, kl_loss, key

    def train(self, train_x, N):
        for epoch in range(self.epochs):
            # Reshuffle the training set at the start of every epoch.
            self.key, permutation_key = jax.random.split(self.key, 2)
            shuffled_indices = jax.random.permutation(permutation_key, N)
            now_data = train_x[shuffled_indices, :]
            tot_reconstruction_loss, tot_kl_loss = 0, 0
            for i in range(0, N, self.batch_size):
                batch_x = now_data[i: i + self.batch_size]
                self.encoder_state, self.decoder_state, reconstruction_loss, kl_loss, self.key = VAE.train_step(
                    batch_x, self.encoder_state, self.decoder_state, self.key)
                tot_reconstruction_loss += reconstruction_loss
                tot_kl_loss += kl_loss

            now = datetime.now()
            time_str = now.strftime("%Y-%m-%d %H:%M:%S")
            print(f"Epoch {epoch + 1}, Reconstruction_loss: {tot_reconstruction_loss / N:.4f}, KL_loss: {tot_kl_loss / N:.4f}! Time: {time_str}.")
            self.logger.record("Reconstruction_loss", float(tot_reconstruction_loss) / N)
            self.logger.record("KL_loss", float(tot_kl_loss) / N)
            self.logger.dump(step=epoch + 1)

def plot(test_x, test_y, VAE_model):
    original_images = []

    # Pick one test image per digit at a fixed index (so the selection is deterministic).
    for digit in range(10):
        original_image = [test_x[i] for i in range(len(test_x)) if test_y[i] == digit][111]
        original_images.append(jnp.asarray(original_image.reshape(784)))

    inputs = jnp.stack(original_images, axis=0).astype(jnp.float32)
    output, VAE_model.key = VAE_model.forward(inputs, VAE_model.encoder_state, VAE_model.decoder_state, VAE_model.key)
    generated_images = list(np.asarray(output))

    fig, axes = plt.subplots(2, 10, figsize=(15, 3))

    for i in range(10):
        axes[0, i].imshow(original_images[i].reshape(28, 28), cmap='gray')
        axes[0, i].set_title(f"Original {i}")
        axes[0, i].axis('off')

        axes[1, i].imshow(generated_images[i].reshape(28, 28), cmap='gray')
        axes[1, i].set_title(f"Generated {i}")
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.savefig('result.png')
    plt.show()

def main():
    start_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_path = f"logs/VAEtest_{start_time}/"
    logger = configure(log_path, ["tensorboard"])
    train_x, train_y, test_x, test_y, N, M = get_data()
    key = jax.random.PRNGKey(41)
    VAEmodel = VAE(input_dim=784, encoder_lr=0.001, decoder_lr=0.001, epochs=30, batch_size=128, logger=logger, key=key)
    VAEmodel.train(train_x, N)
    plot(test_x, test_y, VAEmodel)

if __name__ == '__main__':
    main()
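
As a side note (not part of the original experiment), the trained decoder can also generate new digits directly by decoding latents sampled from the prior N(0, I) instead of encoding real images. A minimal sketch, assuming the trained VAEmodel from main() above:

# Hypothetical usage: decode latents sampled from the prior N(0, I).
sample_key = jax.random.PRNGKey(0)
z = jax.random.normal(sample_key, shape=(10, VAEmodel.latent_dim))
samples = VAEmodel.decoder_state.apply_fn({"params": VAEmodel.decoder_state.params}, z)
# samples has shape (10, 784); reshape each row to (28, 28) for display.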

Experimental Results

The reconstruction loss decreases steadily:

[Figure: TensorBoard curve of the reconstruction loss over training]

One example of each digit is taken from the test set (at a fixed index, so not actually random) and shown next to its reconstruction. Honestly the results fall short of expectations; some hyperparameter tuning is probably needed (one possible tweak is sketched below), but I couldn't be bothered:

[Figure: original digits (top row) vs. their reconstructions (bottom row)]
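
One common tweak for MNIST VAEs, offered here as an untested suggestion rather than something this post verified: since the inputs are normalized to [0, 1] and the decoder ends in a sigmoid, the squared-error reconstruction term in loss_fn can be swapped for a Bernoulli cross-entropy, e.g.:

# Hypothetical alternative reconstruction term for loss_fn (not used in the run above):
# Bernoulli negative log-likelihood, matching the sigmoid output x_ of the decoder.
eps_clip = 1e-7  # for numerical stability inside the logs
reconstruction_loss = -jnp.sum(
    data * jnp.log(x_ + eps_clip) + (1.0 - data) * jnp.log(1.0 - x_ + eps_clip)
)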
