[K3VAE1] - First VAE, using functional API (MNIST dataset)¶
Construction and training of a VAE, using the functional API, with a low-dimensional latent space.
Objectives :¶
- Understanding and implementing a variational autoencoder neural network (VAE)
- Understanding the Keras functional API, using two custom layers
Because the computational cost is significant, it is preferable to start with a very simple dataset such as MNIST.
...and to use MNIST at a small scale if you don't have a GPU ;-)
What we're going to do :¶
- Define a VAE model
- Build the model
- Train it
- Have a look at the training process
Acknowledgements :¶
Thanks to François Chollet, whose work is the basis of this example (and who created Keras!).
See : https://keras.io/examples/generative/vae
Step 1 - Init python stuff¶
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
from keras import layers
import numpy as np
from modules.layers import SamplingLayer, VariationalLossLayer
from modules.callbacks import ImagesCallback
from modules.datagen import MNIST
import sys
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3VAE1')
FIDLE - Environment initialization
Version : 2.3.2
Run id : K3VAE1
Run dir : ./run/K3VAE1
Datasets dir : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time : 22/12/24 21:36:40
Hostname : r3i6n0 (Linux)
Tensorflow log level : Info + Warning + Error (=0)
Update keras cache : False
Update torch cache : False
Save figs : ./run/K3VAE1/figs (True)
keras : 3.7.0
numpy : 2.1.2
sklearn : 1.5.2
yaml : 6.0.2
skimage : 0.24.0
matplotlib : 3.9.2
pandas : 2.2.3
torch : 2.5.0
Step 2 - Parameters¶
scale : With scale=1, we need 1'30s on a V100 GPU ...and more than 20' on a CPU!
latent_dim : 2 dimensions is small, but useful for plotting!
fit_verbosity : Verbosity of the training progress bar: 0=silent, 1=progress bar, 2=one line per epoch
loss_weights : Our loss function is the weighted sum of two losses:
- r_loss, which measures the reconstruction loss.
- kl_loss, which measures the dispersion (the Kullback-Leibler divergence between the encoded distribution and a standard normal distribution).
The weights are defined by: loss_weights=[k1,k2]
where : total_loss = k1*r_loss + k2*kl_loss
In practice, a value of [1,.06] gives good results here.
With scale=0.2, epochs=10 : 3'30 on a laptop
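To make the weighted sum concrete, here is a minimal NumPy sketch of how the two terms could be combined for one batch. It assumes a binary cross-entropy reconstruction term summed over pixels and the usual closed-form KL term for a diagonal Gaussian; the function name `vae_loss` is hypothetical, and the actual VariationalLossLayer may scale or reduce its terms differently.

```python
import numpy as np

def vae_loss(x, x_hat, z_mean, z_log_var, loss_weights=(1.0, 0.06)):
    """Illustrative weighted VAE loss for one batch (assumed formulation)."""
    k1, k2 = loss_weights
    eps = 1e-7
    # Avoid log(0) by keeping predictions strictly inside (0, 1).
    x_hat = np.clip(x_hat, eps, 1.0 - eps)
    # Reconstruction term: binary cross-entropy, summed over pixels, averaged over the batch.
    bce = -(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat))
    r_loss = bce.reshape(len(x), -1).sum(axis=1).mean()
    # KL term: divergence between N(z_mean, exp(z_log_var)) and N(0, 1),
    # summed over latent dimensions, averaged over the batch.
    kl = -0.5 * (1.0 + z_log_var - z_mean**2 - np.exp(z_log_var))
    kl_loss = kl.sum(axis=1).mean()
    return k1 * r_loss + k2 * kl_loss, r_loss, kl_loss

# Tiny example: one 4-pixel "image" and a 2-dimensional latent space.
x = np.array([[0.0, 1.0, 1.0, 0.0]])
x_hat = np.array([[0.1, 0.8, 0.9, 0.2]])
total_loss, r_loss, kl_loss = vae_loss(x, x_hat, np.ones((1, 2)), np.zeros((1, 2)))
print(total_loss, r_loss, kl_loss)
```

A larger k2 pulls the encoded points closer to a standard normal distribution, at the expense of reconstruction quality.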
latent_dim = 2
loss_weights = [1,.06]
scale = 0.2
seed = 123
batch_size = 64
epochs = 10
fit_verbosity = 1
Override parameters (batch mode) - Just forget this cell
fidle.override('latent_dim', 'loss_weights', 'scale', 'seed', 'batch_size', 'epochs', 'fit_verbosity')
** Overrided parameters : ** scale : 1 epochs : 20 fit_verbosity : 2
Step 3 - Prepare data¶
MNIST.get_data() returns x_train, y_train, x_test, y_test, but we only need x_train for our training.
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )
fidle.scrawler.images(x_data[:20], None, indices='all', columns=10, x_size=1,y_size=1,y_padding=0, save_as='01-original')
Seeded (123)
Dataset loaded. Concatenated. Shuffled. rescaled (1). Normalized. Reshaped. splited (1).
x_train shape is : (70000, 28, 28, 1)
x_test shape is : (0, 28, 28, 1)
y_train shape is : (70000,)
y_test shape is : (0,)
Blake2b digest is : 0c903710d4d28b01c174
Step 4 - Build model¶
In this example, we will use the functional API.
For this, we will use two custom layers :
- SamplingLayer, which generates a vector z from the parameters z_mean and z_log_var - See : SamplingLayer.py
- VariationalLossLayer, which allows us to calculate the loss function, loss - See : VariationalLossLayer.py
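Generating z from z_mean and z_log_var is usually done with the reparameterization trick: z = z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon drawn from a standard normal distribution. The NumPy sketch below illustrates that computation only; the function name `sample_z` is hypothetical, and the real SamplingLayer works on backend tensors and may differ in detail.

```python
import numpy as np

def sample_z(z_mean, z_log_var, rng):
    """Draw z ~ N(z_mean, exp(z_log_var)) using the reparameterization trick."""
    # Noise is sampled independently of the parameters, so gradients
    # can flow through z_mean and z_log_var in a real training setup.
    epsilon = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * epsilon

rng = np.random.default_rng(123)
z_mean = np.zeros((4, 2))            # batch of 4, latent_dim = 2
z_log_var = np.full((4, 2), -50.0)   # tiny variance: z stays very close to z_mean
z = sample_z(z_mean, z_log_var, rng)
print(z.shape)
```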
Encoder¶
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, strides=1, padding="same", activation="relu")(inputs)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=1, padding="same", activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = SamplingLayer()([z_mean, z_log_var])
encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
# encoder.summary()
Decoder¶
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, strides=1, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
outputs = layers.Conv2DTranspose(1, 3, padding="same", activation="sigmoid")(x)
decoder = keras.Model(inputs, outputs, name="decoder")
# decoder.summary()
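The encoder and decoder above mirror each other, which is easier to see by following the spatial size of the feature maps. The sketch below is plain arithmetic, assuming the standard behaviour of padding="same": a strided convolution gives ceil(size / stride), and a transposed convolution gives size * stride. The helper names are hypothetical.

```python
import math

def conv_same(size, stride):
    """Spatial output size of a Conv2D with padding='same'."""
    return math.ceil(size / stride)

def conv_transpose_same(size, stride):
    """Spatial output size of a Conv2DTranspose with padding='same'."""
    return size * stride

# Strides taken from the encoder and decoder definitions above
# (the last Conv2DTranspose uses the default stride of 1).
encoder_strides = [1, 2, 2, 1]
decoder_strides = [1, 2, 2, 1]

size = 28
encoder_sizes = [size]
for stride in encoder_strides:
    size = conv_same(size, stride)
    encoder_sizes.append(size)

# Number of features seen by Flatten: 7 * 7 * 64 channels.
flat_features = size * size * 64

size = 7
decoder_sizes = [size]
for stride in decoder_strides:
    size = conv_transpose_same(size, stride)
    decoder_sizes.append(size)

print(encoder_sizes, flat_features, decoder_sizes)
```

This is why the decoder starts with Dense(7 * 7 * 64) followed by Reshape((7, 7, 64)): it rebuilds the shape that the encoder had just before Flatten.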
VAE¶
The loss is computed inside a dedicated layer, VariationalLossLayer, which is why the model is compiled with loss=None.
See our : modules.layers.VariationalLossLayer.py
inputs = keras.Input(shape=(28, 28, 1))
z_mean, z_log_var, z = encoder(inputs)
outputs = decoder(z)
outputs = VariationalLossLayer(loss_weights=loss_weights)([inputs, z_mean, z_log_var, outputs])
vae=keras.Model(inputs,outputs)
vae.compile(optimizer='adam', loss=None)
Step 5 - Train¶
5.1 - Using two nice custom callbacks :-)¶
Two custom callbacks are used:
- ImagesCallback : saves images during training - See ImagesCallback.py
- BestModelCallback : saves the best model - See BestModelCallback.py (not added to callbacks_list in the cell below)
callback_images = ImagesCallback(x=x_data, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)
callbacks_list = [callback_images]
5.2 - Let's train !¶
With scale=1, training needs about 1'15 on a GPU (V100 at IDRIS) ...or 20' on a CPU
chrono=fidle.Chrono()
chrono.start()
history = vae.fit(x_data, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)
chrono.show()
Epoch 1/20
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/backend/common/backend_utils.py:91: UserWarning: You might experience inconsistencies across backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1. warnings.warn(
1094/1094 - 14s - 13ms/step - loss: 11689.1309
Epoch 2/20
1094/1094 - 14s - 13ms/step - loss: 9795.8291
Epoch 3/20
1094/1094 - 14s - 13ms/step - loss: 9460.4062
Epoch 4/20
1094/1094 - 14s - 13ms/step - loss: 9293.0684
Epoch 5/20
1094/1094 - 14s - 13ms/step - loss: 9182.6094
Epoch 6/20
1094/1094 - 14s - 13ms/step - loss: 9101.9863
Epoch 7/20
1094/1094 - 14s - 13ms/step - loss: 9036.9434
Epoch 8/20
1094/1094 - 14s - 13ms/step - loss: 8981.9365
Epoch 9/20
1094/1094 - 14s - 13ms/step - loss: 8932.2832
Epoch 10/20
1094/1094 - 14s - 13ms/step - loss: 8889.8096
Epoch 11/20
1094/1094 - 14s - 13ms/step - loss: 8860.3467
Epoch 12/20
1094/1094 - 14s - 13ms/step - loss: 8824.7744
Epoch 13/20
1094/1094 - 14s - 13ms/step - loss: 8784.6133
Epoch 14/20
1094/1094 - 14s - 13ms/step - loss: 8765.9014
Epoch 15/20
1094/1094 - 14s - 13ms/step - loss: 8738.8848
Epoch 16/20
1094/1094 - 14s - 13ms/step - loss: 8720.6484
Epoch 17/20
1094/1094 - 14s - 13ms/step - loss: 8692.5205
Epoch 18/20
1094/1094 - 14s - 13ms/step - loss: 8674.5615
Epoch 19/20
1094/1094 - 14s - 13ms/step - loss: 8651.4570
Epoch 20/20
1094/1094 - 14s - 13ms/step - loss: 8640.1475
Duration : 277.52 seconds
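The figures in this log can be cross-checked with a little arithmetic. The sketch below uses only values reported in this run (70000 samples at scale=1, batch_size=64, 20 epochs after override, about 14 s per epoch); the variable names are chosen for illustration.

```python
import math

n_samples = 70000        # x_train shape reported in Step 3 with scale=1
batch_size = 64          # from the parameters cell
epochs = 20              # overridden value for this run
seconds_per_epoch = 14   # rounded value reported in the log

# The last batch is incomplete, so the number of steps is rounded up.
steps_per_epoch = math.ceil(n_samples / batch_size)

# Rough total, to compare with the measured duration of 277.52 seconds.
estimated_seconds = epochs * seconds_per_epoch

print(steps_per_epoch, estimated_seconds)
```

The step count matches the 1094/1094 shown for each epoch, and the estimate is close to the measured duration because the per-epoch time in the log is rounded to the second.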
fidle.scrawler.history(history, plot={"Loss":['loss']}, save_as='history')
6.2 - Reconstruction during training¶
At the end of each epoch, our callback saved some reconstructed images.
Where :
Original image -> encoder -> z -> decoder -> Reconstructed image
images_z, images_r = callback_images.get_images( range(0,epochs,2) )
fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)
fidle.utils.subtitle('Encoded/decoded images')
fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-reconstruct')
fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)
Original images :
Encoded/decoded images
Original images :
6.3 - Generation (latent -> decoder)¶
fidle.utils.subtitle('Generated images from latent space')
fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-generated')
Generated images from latent space
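Since the latent space has only 2 dimensions, it can also be explored systematically by decoding a regular grid of points. The NumPy sketch below only builds such a grid; the function name `latent_grid` and the range [-2, 2] are assumptions for illustration, and this is independent of how ImagesCallback chooses its own points.

```python
import numpy as np

def latent_grid(n=5, low=-2.0, high=2.0):
    """Return an (n*n, 2) array of regularly spaced points in a 2D latent space."""
    axis = np.linspace(low, high, n)
    zx, zy = np.meshgrid(axis, axis)
    # One row per latent point: [z_x, z_y]
    return np.stack([zx.ravel(), zy.ravel()], axis=1)

z_grid = latent_grid(5)
print(z_grid.shape)

# With the decoder built in this notebook, the grid could then be decoded with
# something like: images = decoder.predict(z_grid)
```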
Appendix - Model save and reload¶
Save our model
os.makedirs(f'{run_dir}/models', exist_ok=True)
filename = run_dir+'/models/my_model.keras'
vae.save(filename)
Reload it
vae_reloaded = keras.models.load_model( filename,
custom_objects={ 'SamplingLayer': SamplingLayer,
'VariationalLossLayer':VariationalLossLayer})
Play with our decoder !
decoder = vae.get_layer('decoder')
img = decoder( np.array([[-1,.1]]))
fidle.scrawler.images(img.detach().cpu().numpy(), x_size=2,y_size=2, save_as='04-example')
fidle.end()
End time : 22/12/24 21:41:41
Duration : 00:05:01 652ms
This notebook ends here :-)
https://fidle.cnrs.fr