
[K3VAE1] - First VAE, using functional API (MNIST dataset)¶

Construction and training of a VAE, using the functional API, with a latent space of small dimension.

Objectives :¶

  • Understanding and implementing a variational autoencoder (VAE) neural network
  • Understanding Keras functional API, using two custom layers

Since the required computation is significant, it is preferable to start with a very simple dataset such as MNIST.
...MNIST at a small scale if you don't have a GPU ;-)

What we're going to do :¶

  • Defining a VAE model
  • Build the model
  • Train it
  • Have a look at the training process

Acknowledgements :¶

Thanks to François Chollet, who is at the origin of this example (and is the creator of Keras !!).
See : https://keras.io/examples/generative/vae

Step 1 - Init python stuff¶

In [1]:
import os
os.environ['KERAS_BACKEND'] = 'torch'

import keras
from keras import layers

import numpy as np

from modules.layers    import SamplingLayer, VariationalLossLayer
from modules.callbacks import ImagesCallback
from modules.datagen   import MNIST

import sys
import fidle

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3VAE1')


FIDLE - Environment initialization

Version              : 2.3.2
Run id               : K3VAE1
Run dir              : ./run/K3VAE1
Datasets dir         : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 22/12/24 21:36:40
Hostname             : r3i6n0 (Linux)
Tensorflow log level : Info + Warning + Error  (=0)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3VAE1/figs (True)
keras                : 3.7.0
numpy                : 2.1.2
sklearn              : 1.5.2
yaml                 : 6.0.2
skimage              : 0.24.0
matplotlib           : 3.9.2
pandas               : 2.2.3
torch                : 2.5.0

Step 2 - Parameters¶

scale : With scale=1, we need 1'30s on a V100 GPU ...and >20' on a CPU !
latent_dim : 2 dimensions is small, but useful for drawing !
fit_verbosity : verbosity of the training progress bar: 0=silent, 1=progress bar, 2=one line per epoch

loss_weights : Our loss function is the weighted sum of two losses:

  • r_loss, which measures the reconstruction error.
  • kl_loss, which measures the dispersion of the latent distribution.

The weights are defined by loss_weights=[k1,k2], where : total_loss = k1*r_loss + k2*kl_loss
In practice, a value of [1,.06] gives good results here.
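As an illustration, classic definitions of these two terms are a pixel-wise binary cross-entropy for r_loss and a Gaussian KL divergence for kl_loss. Here is a sketch; the exact reductions used by VariationalLossLayer may differ:

import keras
from keras import ops

def r_loss(x, x_hat):
    # Per-image binary cross-entropy between input and reconstruction, summed
    # over the pixels, then averaged over the batch
    return ops.mean(ops.sum(keras.losses.binary_crossentropy(x, x_hat), axis=(1, 2)))

def kl_loss(z_mean, z_log_var):
    # KL divergence between N(z_mean, exp(z_log_var)) and the N(0, I) prior,
    # summed over the latent dimensions, then averaged over the batch
    return ops.mean(-0.5 * ops.sum(1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var), axis=1))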

With scale=0.2, epochs=10 : 3'30 on a laptop

In [2]:
latent_dim    = 2
loss_weights  = [1,.06]

scale         = 0.2
seed          = 123

batch_size    = 64
epochs        = 10
fit_verbosity = 1

Override parameters (batch mode) - Just forget this cell

In [3]:
fidle.override('latent_dim', 'loss_weights', 'scale', 'seed', 'batch_size', 'epochs', 'fit_verbosity')
** Overrided parameters : **
scale                : 1
epochs               : 20
fit_verbosity        : 2

Step 3 - Prepare data¶

MNIST.get_data() returns x_train, y_train, x_test, y_test,
but we only need x_train for our training.

In [4]:
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )

fidle.scrawler.images(x_data[:20], None, indices='all', columns=10, x_size=1,y_size=1,y_padding=0, save_as='01-original')
Seeded (123)
Dataset loaded.
Concatenated.
Shuffled.
rescaled (1).
Normalized.
Reshaped.
splited (1).
x_train shape is  :  (70000, 28, 28, 1)
x_test  shape is  :  (0, 28, 28, 1)
y_train shape is  :  (70000,)
y_test  shape is  :  (0,)
Blake2b digest is :  0c903710d4d28b01c174
Saved: ./run/K3VAE1/figs/01-original

Step 4 - Build model¶

In this example, we will use the functional API.
For this, we will use two custom layers :

  • SamplingLayer, which generates a vector z from the parameters z_mean and z_log_var (a minimal sketch is shown after this list) - See : SamplingLayer.py
  • VariationalLossLayer, which allows us to compute the loss - See : VariationalLossLayer.py
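A minimal version of the sampling layer, implementing the reparameterization trick, could look like this (a sketch; see SamplingLayer.py for the version actually used):

class MySamplingLayer(keras.layers.Layer):
    # Reparameterization trick: z = z_mean + exp(z_log_var/2) * epsilon, with
    # epsilon ~ N(0, I), so gradients can flow back through z_mean and z_log_var
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = keras.random.normal(shape=keras.ops.shape(z_mean))
        return z_mean + keras.ops.exp(0.5 * z_log_var) * epsilon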

Encoder¶

In [5]:
inputs    = keras.Input(shape=(28, 28, 1))
x         = layers.Conv2D(32, 3, strides=1, padding="same", activation="relu")(inputs)
x         = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x         = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x         = layers.Conv2D(64, 3, strides=1, padding="same", activation="relu")(x)
x         = layers.Flatten()(x)
x         = layers.Dense(16, activation="relu")(x)

z_mean    = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z         = SamplingLayer()([z_mean, z_log_var])

encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
# encoder.summary()

Decoder¶

In [6]:
inputs  = keras.Input(shape=(latent_dim,))
x       = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x       = layers.Reshape((7, 7, 64))(x)
x       = layers.Conv2DTranspose(64, 3, strides=1, padding="same", activation="relu")(x)
x       = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x       = layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
outputs = layers.Conv2DTranspose(1,  3, padding="same", activation="sigmoid")(x)

decoder = keras.Model(inputs, outputs, name="decoder")

# decoder.summary()

VAE¶

We will calculate the loss with a specific layer: VariationalLossLayer
See our : modules.layers.VariationalLossLayer.py
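Such a layer computes the losses from its inputs, registers them with add_loss, and returns the reconstruction unchanged. A minimal sketch, reusing the r_loss and kl_loss sketches from Step 2 (the actual VariationalLossLayer may differ in its reductions and options):

class MyVariationalLossLayer(keras.layers.Layer):
    def __init__(self, loss_weights=[1, .06], **kwargs):
        super().__init__(**kwargs)
        self.k1, self.k2 = loss_weights
    def call(self, inputs):
        x, z_mean, z_log_var, x_hat = inputs
        # Register the weighted sum: total_loss = k1*r_loss + k2*kl_loss
        self.add_loss(self.k1 * r_loss(x, x_hat) + self.k2 * kl_loss(z_mean, z_log_var))
        # Pass the reconstruction through unchanged
        return x_hat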

In [7]:
inputs = keras.Input(shape=(28, 28, 1))

z_mean, z_log_var, z = encoder(inputs)
outputs              = decoder(z)

outputs = VariationalLossLayer(loss_weights=loss_weights)([inputs, z_mean, z_log_var, outputs])

vae=keras.Model(inputs,outputs)

vae.compile(optimizer='adam', loss=None)

Step 5 - Train¶

5.1 - Using two nice custom callbacks :-)¶

Two custom callbacks can be used:

  • ImagesCallback : saves some images during training - See ImagesCallback.py
  • BestModelCallback : saves the best model - See BestModelCallback.py
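For reference, the skeleton of such a callback is simple. Here is a sketch with illustrative names; the real ImagesCallback does more (notably generation from random z vectors), and is the only callback actually used below:

class MyImagesCallback(keras.callbacks.Callback):
    def __init__(self, x, nb_images=5):
        super().__init__()
        self.x = x[:nb_images]
    def on_epoch_end(self, epoch, logs=None):
        # Reconstruct the same few images at the end of each epoch...
        x_hat = self.model.predict(self.x, verbose=0)
        # ...then save x_hat to disk, e.g. somewhere under run_dir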
In [8]:
callback_images      = ImagesCallback(x=x_data, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)

callbacks_list = [callback_images]

5.2 - Let's train !¶

With scale=1, it needs about 1'15 on a GPU (V100 at IDRIS) ...or 20' on a CPU

In [9]:
chrono=fidle.Chrono()
chrono.start()

history = vae.fit(x_data, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)

chrono.show()
Epoch 1/20
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/backend/common/backend_utils.py:91: UserWarning: You might experience inconsistencies across backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1.
  warnings.warn(
1094/1094 - 14s - 13ms/step - loss: 11689.1309
Epoch 2/20
1094/1094 - 14s - 13ms/step - loss: 9795.8291
Epoch 3/20
1094/1094 - 14s - 13ms/step - loss: 9460.4062
Epoch 4/20
1094/1094 - 14s - 13ms/step - loss: 9293.0684
Epoch 5/20
1094/1094 - 14s - 13ms/step - loss: 9182.6094
Epoch 6/20
1094/1094 - 14s - 13ms/step - loss: 9101.9863
Epoch 7/20
1094/1094 - 14s - 13ms/step - loss: 9036.9434
Epoch 8/20
1094/1094 - 14s - 13ms/step - loss: 8981.9365
Epoch 9/20
1094/1094 - 14s - 13ms/step - loss: 8932.2832
Epoch 10/20
1094/1094 - 14s - 13ms/step - loss: 8889.8096
Epoch 11/20
1094/1094 - 14s - 13ms/step - loss: 8860.3467
Epoch 12/20
1094/1094 - 14s - 13ms/step - loss: 8824.7744
Epoch 13/20
1094/1094 - 14s - 13ms/step - loss: 8784.6133
Epoch 14/20
1094/1094 - 14s - 13ms/step - loss: 8765.9014
Epoch 15/20
1094/1094 - 14s - 13ms/step - loss: 8738.8848
Epoch 16/20
1094/1094 - 14s - 13ms/step - loss: 8720.6484
Epoch 17/20
1094/1094 - 14s - 13ms/step - loss: 8692.5205
Epoch 18/20
1094/1094 - 14s - 13ms/step - loss: 8674.5615
Epoch 19/20
1094/1094 - 14s - 13ms/step - loss: 8651.4570
Epoch 20/20
1094/1094 - 14s - 13ms/step - loss: 8640.1475
Duration :  277.52 seconds

Step 6 - Training review¶

6.1 - History¶

In [10]:
fidle.scrawler.history(history,  plot={"Loss":['loss']}, save_as='history')
Saved: ./run/K3VAE1/figs/history_0

6.2 - Reconstruction during training¶

At the end of each epoch, our callback saved some reconstructed images, following the path:
Original image -> encoder -> z -> decoder -> Reconstructed image
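The same round trip can also be done by hand with the two sub-models built above (a sketch):

# Encode a few images, then decode the sampled z
z_mean, z_log_var, z = encoder.predict(x_data[:5], verbose=0)
x_hat = decoder.predict(z, verbose=0)    # x_hat.shape -> (5, 28, 28, 1)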

In [11]:
images_z, images_r = callback_images.get_images( range(0,epochs,2) )

fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)

fidle.utils.subtitle('Encoded/decoded images')
fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-reconstruct')



Original images :



Encoded/decoded images

Saved: ./run/K3VAE1/figs/02-reconstruct



6.3 - Generation (latent -> decoder)¶

In [12]:
fidle.utils.subtitle('Generated images from latent space')
fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-generated')


Generated images from latent space

Saved: ./run/K3VAE1/figs/03-generated

Annex - Model save and reload¶

Save our model

In [13]:
os.makedirs(f'{run_dir}/models', exist_ok=True)

filename = run_dir+'/models/my_model.keras'

vae.save(filename)

Reload it

In [14]:
vae_reloaded = keras.models.load_model( filename, 
                                        custom_objects={ 'SamplingLayer': SamplingLayer, 
                                                         'VariationalLossLayer':VariationalLossLayer})
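
A quick sanity check on the reloaded model (a sketch):

# The reloaded VAE should reconstruct images just like the original one
x_hat = vae_reloaded.predict(x_data[:5], verbose=0)
print(x_hat.shape)    # expected : (5, 28, 28, 1)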

Play with our decoder !

In [15]:
decoder = vae.get_layer('decoder')

img = decoder( np.array([[-1,.1]]))
fidle.scrawler.images(img.detach().cpu().numpy(), x_size=2,y_size=2, save_as='04-example')
Saved: ./run/K3VAE1/figs/04-example
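Since the latent space has only 2 dimensions, we can also sweep it on a small grid and decode every point, following the same pattern as above (a sketch):

# Decode a 5x5 grid of latent points
grid = np.array([ [x,y] for y in np.linspace(-2, 2, 5) for x in np.linspace(-2, 2, 5) ], dtype='float32')
imgs = decoder( grid )                  # torch backend -> returns a torch tensor
imgs = imgs.detach().cpu().numpy()      # back to numpy, shape (25, 28, 28, 1)
fidle.scrawler.images(imgs, None, indices='all', columns=5, x_size=1, y_size=1, save_as=None)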
In [16]:
fidle.end()

End time : 22/12/24 21:41:41
Duration : 00:05:01 652ms
This notebook ends here :-)
https://fidle.cnrs.fr


