[K3VAE2] - VAE, using a custom model class (MNIST dataset)¶
Construction and training of a VAE, using a model subclass, with a small-dimensional latent space.
Objectives :¶
- Understanding and implementing a variational autoencoder (VAE) neural network
- Understanding a more advanced programming model, based on a custom model class
Since the computation is demanding, it is preferable to start with a very simple dataset such as MNIST.
...MNIST at a small scale if you don't have a GPU ;-)
What we're going to do :¶
- Defining a VAE model
- Build the model
- Train it
- Have a look at the training process
Acknowledgements :¶
Thanks to François Chollet, whose example this notebook is based on (and who is the creator of Keras !!).
See : https://keras.io/examples/generative/vae
Step 1 - Init python stuff¶
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
from keras import layers
import numpy as np
from modules.models import VAE
from modules.layers import SamplingLayer
from modules.callbacks import ImagesCallback
from modules.datagen import MNIST
import matplotlib.pyplot as plt
import scipy.stats
import sys
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3VAE2')
VAE.about()
FIDLE - Environment initialization
Version              : 2.3.0
Run id               : K3VAE2
Run dir              : ./run/K3VAE2
Datasets dir         : /gpfswork/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 03/03/24 21:22:39
Hostname             : r6i0n6 (Linux)
Tensorflow log level : Warning + Error (=1)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3VAE2/figs (True)
keras                : 3.0.4
numpy                : 1.24.4
sklearn              : 1.3.2
yaml                 : 6.0.1
skimage              : 0.22.0
matplotlib           : 3.8.2
pandas               : 2.1.3
torch                : 2.1.1
FIDLE 2024 - VAE
Version : 2.0
Keras version : 3.0.4
Step 2 - Parameters¶
scale : with scale=1, we need about 1'30 on a GPU V100 ...and >20' on a CPU !
latent_dim : 2 dimensions is small, but useful for drawing !
fit_verbosity : verbosity of the training progress bar: 0 = silent, 1 = progress bar, 2 = one line per epoch
loss_weights : our loss function is the weighted sum of two losses :
- r_loss, which measures the reconstruction error
- kl_loss, which measures the dispersion of the latent space
The weights are defined by loss_weights=[k1,k2], where : total_loss = k1*r_loss + k2*kl_loss
In practice, a value of [1,.06] gives good results here (a sketch of this loss follows below).
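As an illustration, here is a minimal sketch of how such a weighted loss can be computed (the helper below is hypothetical; the real implementation lives in VAE.py), assuming 28x28x1 inputs in [0,1] :
import keras

def weighted_vae_loss(x, x_reconst, z_mean, z_log_var, k1=1.0, k2=0.06):
    # Hypothetical helper, not the Fidle implementation
    # r_loss : reconstruction error, summed over each 28x28 image
    r_loss  = keras.ops.sum(keras.losses.binary_crossentropy(x, x_reconst), axis=(1, 2))
    # kl_loss : KL divergence between N(z_mean, exp(z_log_var)) and N(0,1)
    kl_loss = -0.5 * keras.ops.sum(1 + z_log_var - keras.ops.square(z_mean) - keras.ops.exp(z_log_var), axis=1)
    # total_loss = k1*r_loss + k2*kl_loss, averaged over the batch
    return k1 * keras.ops.mean(r_loss) + k2 * keras.ops.mean(kl_loss)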
latent_dim = 6
loss_weights = [1,.06]
scale = .2
seed = 123
batch_size = 64
epochs = 4
fit_verbosity = 1
Override parameters (batch mode) - Just forget this cell
fidle.override('latent_dim', 'loss_weights', 'scale', 'seed', 'batch_size', 'epochs', 'fit_verbosity')
** Overrided parameters : **
scale  : 1
epochs : 20
Step 3 - Prepare data¶
MNIST.get_data() returns x_train, y_train, x_test, y_test, but we only need x_train for our training.
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )
fidle.scrawler.images(x_data[:20], None, indices='all', columns=10, x_size=1,y_size=1,y_padding=0, save_as='01-original')
Seeded (123)
Dataset loaded.
Concatenated.
Shuffled.
rescaled (1).
Normalized.
Reshaped.
splited (1).
x_train shape is : (70000, 28, 28, 1)
x_test  shape is : (0, 28, 28, 1)
y_train shape is : (70000,)
y_test  shape is : (0,)
Blake2b digest is : 0c903710d4d28b01c174
Step 4 - Build model¶
In this example, we will use a custom model. For this, we will use :
- SamplingLayer, which generates a vector z from the parameters z_mean and z_log_var - See : SamplingLayer.py
- VAE, a custom model with a specific train_step - See : VAE.py
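For reference, the reparameterization trick behind SamplingLayer can be sketched as follows (a simplified, hypothetical version; see SamplingLayer.py for the real one) :
import keras

class MiniSampling(keras.layers.Layer):
    # Simplified sketch of a sampling layer (not the Fidle one).
    # z = z_mean + sigma * epsilon keeps the graph differentiable
    # w.r.t. z_mean and z_log_var while still drawing a random z
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = keras.random.normal(shape=keras.ops.shape(z_mean))
        return z_mean + keras.ops.exp(0.5 * z_log_var) * epsilon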
Encoder¶
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, strides=1, padding="same", activation="relu")(inputs)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=1, padding="same", activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = SamplingLayer()([z_mean, z_log_var])
encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
encoder.compile()
Decoder¶
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, strides=1, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
outputs = layers.Conv2DTranspose(1, 3, padding="same", activation="sigmoid")(x)
decoder = keras.Model(inputs, outputs, name="decoder")
decoder.compile()
vae = VAE(encoder, decoder, loss_weights)
vae.compile(optimizer='adam')
Fidle VAE is ready :-) loss_weights=[1, 0.06]
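As an optional sanity check (not part of the original notebook), we can push a few images through both halves and verify the shapes :
x_batch = x_data[:8]
z_mean, z_log_var, z = encoder(x_batch)
x_hat = decoder(z)
print(z.shape, x_hat.shape)   # expected : (8, 6) and (8, 28, 28, 1) with latent_dim=6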
Step 5 - Train¶
5.1 - Using two nice custom callbacks :-)¶
Two custom callbacks are used:
- ImagesCallback, which will save images during training - See ImagesCallback.py
- BestModelCallback, which will save the best model - See BestModelCallback.py
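To give an idea of what such a callback does, here is a minimal, hypothetical sketch (much simpler than ImagesCallback.py) that saves reconstructions as .npy arrays at the end of each epoch :
import numpy as np
import keras

class MiniImagesCallback(keras.callbacks.Callback):
    # Hypothetical sketch, not the Fidle ImagesCallback
    def __init__(self, x_samples, run_dir):
        super().__init__()
        self.x_samples = x_samples
        self.run_dir   = run_dir
    def on_epoch_end(self, epoch, logs=None):
        # Reconstruct the reference images with the current weights
        _, _, z = self.model.encoder(self.x_samples)
        x_hat   = self.model.decoder(z)
        np.save(f'{self.run_dir}/reconst_{epoch:03d}.npy', keras.ops.convert_to_numpy(x_hat))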
callback_images = ImagesCallback(x=x_data, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)
callbacks_list = [callback_images]
5.2 - Let's train !¶
With scale=1, this needs about 1'15 on a GPU (V100 at IDRIS) ...or 20' on a CPU.
chrono=fidle.Chrono()
chrono.start()
history = vae.fit(x_data, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)
chrono.show()
Epoch 1/20
/gpfswork/rech/mlh/uja62cb/local/fidle-k3/lib/python3.11/site-packages/keras/src/backend/common/backend_utils.py:88: UserWarning: You might experience inconsistencies accross backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1. warnings.warn(
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 12622.9893
Epoch 2/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 7342.2124
Epoch 3/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6763.4082
Epoch 4/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6521.6929
Epoch 5/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 15s 14ms/step - loss: 6389.1709
Epoch 6/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6298.7036
Epoch 7/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6222.0166
Epoch 8/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6165.6001
Epoch 9/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6117.1689
Epoch 10/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6080.5181
Epoch 11/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6042.1147
Epoch 12/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6027.4399
Epoch 13/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 6000.8267
Epoch 14/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 5968.7407
Epoch 15/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 5937.7476
Epoch 16/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 5908.4883
Epoch 17/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 16s 14ms/step - loss: 5900.8203
Epoch 18/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 15s 14ms/step - loss: 5890.2422
Epoch 19/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 15s 14ms/step - loss: 5873.6401
Epoch 20/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 15s 14ms/step - loss: 5844.3433
Duration : 312.42 seconds
Step 6 - Training review¶
6.1 - History¶
fidle.scrawler.history(history, plot={"Loss":['loss']}, save_as='history')
6.2 - Reconstruction during training¶
At the end of each epoch, our callback saved some reconstructed images.
Where :
Original image -> encoder -> z -> decoder -> Reconstructed image
images_z, images_r = callback_images.get_images( range(0,epochs,2) )
fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')
fidle.utils.subtitle('Encoded/decoded images')
fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')
fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)
Original images :
Encoded/decoded images
Original images :
6.3 - Generation (latent -> decoder) during training¶
fidle.utils.subtitle('Generated images from latent space')
fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-encoded')
Generated images from latent space
6.4 - Save model¶
os.makedirs(f'{run_dir}/models', exist_ok=True)
vae.save(f'{run_dir}/models/vae_model.keras')
Step 7 - Model evaluation¶
7.1 - Reload model¶
vae=VAE()
vae.reload(f'{run_dir}/models/vae_model.keras')
Fidle VAE is ready :-) loss_weights=[1, 1]
Reloaded.
/gpfswork/rech/mlh/uja62cb/local/fidle-k3/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:394: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 16 variables whereas the saved optimizer has 2 variables. trackable.load_own_variables(weights_store.get(inner_path))
/gpfswork/rech/mlh/uja62cb/local/fidle-k3/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:394: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 12 variables whereas the saved optimizer has 2 variables. trackable.load_own_variables(weights_store.get(inner_path))
7.2 - Image reconstruction¶
# ---- Select a few images
x_show = fidle.utils.pick_dataset(x_data, n=10)
# ---- Get latent points and reconstructed images
z_mean, z_var, z = vae.encoder.predict(x_show)
x_reconst = vae.decoder.predict(z)
# ---- Show it
labels=[ str(np.round(z[i],1)) for i in range(10) ]
fidle.scrawler.images(x_show, None, indices='all', columns=10, x_size=2,y_size=2, save_as='05-original')
fidle.scrawler.images(x_reconst, None, indices='all', columns=10, x_size=2,y_size=2, save_as='06-reconstruct')
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
7.3 - Visualization of the latent space¶
n_show = int(20000*scale)
# ---- Select images
x_show, y_show = fidle.utils.pick_dataset(x_data,y_data, n=n_show)
# ---- Get latent points
z_mean, z_var, z = vae.encoder.predict(x_show)
# ---- Show them
fig = plt.figure(figsize=(14, 10))
plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=30)
plt.colorbar()
fidle.scrawler.save_fig('07-Latent-space')
plt.show()
625/625 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step
7.4 - Generative latent space¶
if latent_dim != 2:
    print('Sorry, this part only works when the latent space has dimension 2')
else:
    grid_size  = 18
    grid_scale = 1
    # ---- Draw a ppf grid
    grid = []
    for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size), scale=grid_scale):
        for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size), scale=grid_scale):
            grid.append((x, y))
    grid = np.array(grid)
    # ---- Draw latent points and grid
    fig = plt.figure(figsize=(10, 8))
    plt.scatter(z[:, 0], z[:, 1], c=y_show, cmap='tab10', alpha=0.5, s=20)
    plt.scatter(grid[:, 0], grid[:, 1], c='black', s=60, linewidth=2, marker='+', alpha=1)
    fidle.scrawler.save_fig('08-Latent-grid')
    plt.show()
    # ---- Plot grid corresponding images
    x_reconst = vae.decoder.predict([grid])
    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5, y_size=0.5, y_padding=0, spines_alpha=0.1, save_as='09-Latent-morphing')
Sorry, this part only works when the latent space has dimension 2
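A note on the grid above : scipy.stats.norm.ppf maps evenly spaced probabilities to latent coordinates, so the grid is denser where the N(0,1) prior puts more mass. For example :
import numpy as np
import scipy.stats
print(np.round(scipy.stats.norm.ppf([0.01, 0.25, 0.5, 0.75, 0.99]), 3))
# -> [-2.326 -0.674  0.     0.674  2.326]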
fidle.end()
End time : 03/03/24 21:28:38
Duration : 00:05:59 600ms
This notebook ends here :-)
https://fidle.cnrs.fr