
[K3VAE2] - VAE, using a custom model class (MNIST dataset)¶

Construction and training of a VAE, using a model subclass, with a small-dimensional latent space.

Objectives :¶

  • Understanding and implementing a variational autoencoder (VAE) neural network
  • Understanding an even more advanced programming model, using a custom model class

Since the computation is demanding, it is preferable to start with a very simple dataset such as MNIST.
...and MNIST at a small scale if you don't have a GPU ;-)

What we're going to do :¶

  • Defining a VAE model
  • Build the model
  • Train it
  • Have a look at the training process

Acknowledgements :¶

Thanks to François Chollet, whose example this notebook is based on (and who is the creator of Keras !!).
See : https://keras.io/examples/generative/vae

Step 1 - Init python stuff¶

In [1]:
import os
os.environ['KERAS_BACKEND'] = 'torch'

import keras
from keras import layers

import numpy as np

from modules.models    import VAE
from modules.layers    import SamplingLayer
from modules.callbacks import ImagesCallback
from modules.datagen   import MNIST


import matplotlib.pyplot as plt
import scipy.stats
import sys

import fidle

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3VAE2')

VAE.about()


FIDLE - Environment initialization

Version              : 2.3.2
Run id               : K3VAE2
Run dir              : ./run/K3VAE2
Datasets dir         : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 22/12/24 21:36:54
Hostname             : r3i5n3 (Linux)
Tensorflow log level : Info + Warning + Error  (=0)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3VAE2/figs (True)
keras                : 3.7.0
numpy                : 2.1.2
sklearn              : 1.5.2
yaml                 : 6.0.2
skimage              : 0.24.0
matplotlib           : 3.9.2
pandas               : 2.2.3
torch                : 2.5.0


FIDLE 2024 - VAE

Version              : 2.0
Keras version        : 3.7.0

Step 2 - Parameters¶

scale : with scale=1, we need about 1'30 on a V100 GPU ...and more than 20' on a CPU!
latent_dim : 2 dimensions is small, but useful for drawing!
fit_verbosity : verbosity of the training progress bar: 0=silent, 1=progress bar, 2=one line per epoch

loss_weights : our loss function is the weighted sum of two losses:

  • r_loss, which measures the reconstruction error
  • kl_loss, the Kullback-Leibler divergence, which measures the dispersion of the latent distribution

The weights are defined by loss_weights=[k1,k2], where total_loss = k1*r_loss + k2*kl_loss.
In practice, a value of [1,.06] gives good results here.
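For reference, with the encoder outputs z_mean and z_log_var, the kl_loss term has the standard closed form (a textbook VAE identity, not specific to this notebook):
kl_loss = -0.5 * sum( 1 + z_log_var - z_mean**2 - exp(z_log_var) )
and r_loss is typically a pixel-wise binary cross-entropy between the input image and its reconstruction, as in Chollet's keras.io example.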

In [2]:
latent_dim    = 6
loss_weights  = [1,.06]

scale         = .2
seed          = 123

batch_size    = 64
epochs        = 4
fit_verbosity = 1

Override parameters (batch mode) - Just forget this cell

In [3]:
fidle.override('latent_dim', 'loss_weights', 'scale', 'seed', 'batch_size', 'epochs', 'fit_verbosity')
** Overrided parameters : **
scale                : 1
epochs               : 20
fit_verbosity        : 2

Step 3 - Prepare data¶

MNIST.get_data() returns x_train, y_train, x_test, y_test,
but we only need x_train for our training.

In [4]:
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )

fidle.scrawler.images(x_data[:20], None, indices='all', columns=10, x_size=1,y_size=1,y_padding=0, save_as='01-original')
Seeded (123)
Dataset loaded.
Concatenated.
Shuffled.
rescaled (1).
Normalized.
Reshaped.
splited (1).
x_train shape is  :  (70000, 28, 28, 1)
x_test  shape is  :  (0, 28, 28, 1)
y_train shape is  :  (70000,)
y_test  shape is  :  (0,)
Blake2b digest is :  0c903710d4d28b01c174
Saved: ./run/K3VAE2/figs/01-original

Step 4 - Build model¶

In this example, we will use a custom model. For this, we will use :

  • SamplingLayer, which generates a vector z from the parameters z_mean and z_log_var (a minimal sketch is given below) - See : SamplingLayer.py
  • VAE, a custom model with a specific train_step - See : VAE.py
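To fix ideas, here is a minimal sketch of what such a sampling layer typically looks like, following the reparameterization trick used in Chollet's keras.io example (the actual SamplingLayer.py in the Fidle modules may differ in its details):

import keras
from keras import ops, layers

class SamplingLayer(layers.Layer):
    """Draws z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(42)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim   = ops.shape(z_mean)[1]
        # epsilon ~ N(0, I); shifting and scaling it keeps the sampling
        # differentiable with respect to z_mean and z_log_var.
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon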

Encoder¶

In [5]:
inputs    = keras.Input(shape=(28, 28, 1))
x         = layers.Conv2D(32, 3, strides=1, padding="same", activation="relu")(inputs)
x         = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x         = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x         = layers.Conv2D(64, 3, strides=1, padding="same", activation="relu")(x)
x         = layers.Flatten()(x)
x         = layers.Dense(16, activation="relu")(x)

z_mean    = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z         = SamplingLayer()([z_mean, z_log_var])

encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
encoder.compile()

Decoder¶

In [6]:
inputs  = keras.Input(shape=(latent_dim,))
x       = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x       = layers.Reshape((7, 7, 64))(x)
x       = layers.Conv2DTranspose(64, 3, strides=1, padding="same", activation="relu")(x)
x       = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x       = layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
outputs = layers.Conv2DTranspose(1,  3, padding="same", activation="sigmoid")(x)

decoder = keras.Model(inputs, outputs, name="decoder")
decoder.compile()

VAE¶

VAE is a custom model with a specific train_step - See : VAE.py
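To give the idea, here is a minimal sketch of such a custom train_step, assuming the torch backend selected in Step 1 and the weighted loss described in Step 2 (the real VAE.py also implements about(), saving/reloading, metric tracking, etc.):

import torch
import keras
from keras import ops

class VAE(keras.Model):
    """Sketch of a VAE whose train_step combines reconstruction and KL losses."""

    def __init__(self, encoder=None, decoder=None, loss_weights=[1, 1], **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.k1, self.k2 = loss_weights

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        x = data
        self.zero_grad()                      # clear leftover torch gradients
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        # r_loss : pixel-wise binary cross-entropy, summed over each image
        r_loss  = ops.mean(ops.sum(
            keras.losses.binary_crossentropy(x, reconstruction), axis=(1, 2)))
        # kl_loss : closed-form KL divergence against the N(0, I) prior
        kl_loss = -0.5 * ops.mean(ops.sum(
            1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var), axis=1))
        loss = self.k1 * r_loss + self.k2 * kl_loss   # total_loss = k1*r_loss + k2*kl_loss
        loss.backward()                       # torch backward pass
        weights = self.trainable_weights
        grads   = [w.value.grad for w in weights]
        with torch.no_grad():
            self.optimizer.apply(grads, weights)
        return {"loss": loss}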

In [7]:
vae = VAE(encoder, decoder, loss_weights)

vae.compile(optimizer='adam')
Fidle VAE is ready :-)  loss_weights=[1, 0.06]

Step 5 - Train¶

5.1 - Using two nice custom callbacks :-)¶

Two custom callbacks can be used:

  • ImagesCallback : saves some images during training - See ImagesCallback.py
  • BestModelCallback : saves the best model - See BestModelCallback.py

In [8]:
callback_images      = ImagesCallback(x=x_data, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)

callbacks_list = [callback_images]

5.2 - Let's train !¶

With scale=1, this takes about 1'15 on a GPU (a V100 at IDRIS) ...or 20' on a CPU

In [9]:
chrono=fidle.Chrono()
chrono.start()

history = vae.fit(x_data, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)

chrono.show()
Epoch 1/20
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/backend/common/backend_utils.py:91: UserWarning: You might experience inconsistencies across backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1.
  warnings.warn(
1094/1094 - 13s - 12ms/step - loss: 10896.0166
Epoch 2/20
1094/1094 - 13s - 12ms/step - loss: 8783.6133
Epoch 3/20
1094/1094 - 13s - 12ms/step - loss: 8510.3066
Epoch 4/20
1094/1094 - 13s - 12ms/step - loss: 8344.8555
Epoch 5/20
1094/1094 - 13s - 12ms/step - loss: 8238.4219
Epoch 6/20
1094/1094 - 13s - 12ms/step - loss: 8166.6108
Epoch 7/20
1094/1094 - 13s - 12ms/step - loss: 8111.5840
Epoch 8/20
1094/1094 - 13s - 12ms/step - loss: 8055.1963
Epoch 9/20
1094/1094 - 13s - 12ms/step - loss: 8022.8188
Epoch 10/20
1094/1094 - 13s - 12ms/step - loss: 7979.8218
Epoch 11/20
1094/1094 - 13s - 12ms/step - loss: 7951.6973
Epoch 12/20
1094/1094 - 13s - 12ms/step - loss: 7922.7065
Epoch 13/20
1094/1094 - 13s - 12ms/step - loss: 7900.1572
Epoch 14/20
1094/1094 - 13s - 12ms/step - loss: 7871.8921
Epoch 15/20
1094/1094 - 13s - 12ms/step - loss: 7854.2183
Epoch 16/20
1094/1094 - 13s - 12ms/step - loss: 7835.1177
Epoch 17/20
1094/1094 - 13s - 12ms/step - loss: 7816.2168
Epoch 18/20
1094/1094 - 13s - 12ms/step - loss: 7793.7368
Epoch 19/20
1094/1094 - 13s - 12ms/step - loss: 7780.6123
Epoch 20/20
1094/1094 - 13s - 12ms/step - loss: 7768.8008
Duration :  256.22 seconds

Step 6 - Training review¶

6.1 - History¶

In [10]:
fidle.scrawler.history(history,  plot={"Loss":['loss']}, save_as='history')
Saved: ./run/K3VAE2/figs/history_0

6.2 - Reconstruction during training¶

At the end of each epoch, our callback saved some reconstructed images, obtained as:
Original image -> encoder -> z -> decoder -> Reconstructed image

In [11]:
images_z, images_r = callback_images.get_images( range(0,epochs,2) )

fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')

fidle.utils.subtitle('Encoded/decoded images')
fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')

fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)


Original images :

Saved: ./run/K3VAE2/figs/02-original


Encoded/decoded images

Saved: ./run/K3VAE2/figs/03-reconstruct


Original images :


6.3 - Generation (latent -> decoder) during training¶

In [12]:
fidle.utils.subtitle('Generated images from latent space')
fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-encoded')


Generated images from latent space

Saved: ./run/K3VAE2/figs/04-encoded

6.4 - Save model¶

In [13]:
os.makedirs(f'{run_dir}/models', exist_ok=True)

vae.save(f'{run_dir}/models/vae_model.keras')

Step 7 - Model evaluation¶

7.1 - Reload model¶

In [14]:
vae=VAE()
vae.reload(f'{run_dir}/models/vae_model.keras')
Fidle VAE is ready :-)  loss_weights=[1, 1]
Reloaded.
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 16 variables whereas the saved optimizer has 2 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 12 variables whereas the saved optimizer has 2 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))

7.2 - Image reconstruction¶

In [15]:
# ---- Select few images

x_show = fidle.utils.pick_dataset(x_data, n=10)

# ---- Get latent points and reconstructed images

z_mean, z_var, z  = vae.encoder.predict(x_show)
x_reconst         = vae.decoder.predict(z)

# ---- Show it

labels=[ str(np.round(z[i],1)) for i in range(10) ]
fidle.scrawler.images(x_show,    None, indices='all', columns=10, x_size=2,y_size=2, save_as='05-original')
fidle.scrawler.images(x_reconst, None, indices='all', columns=10, x_size=2,y_size=2, save_as='06-reconstruct')
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
Saved: ./run/K3VAE2/figs/05-original
Saved: ./run/K3VAE2/figs/06-reconstruct

7.3 - Visualization of the latent space¶

In [16]:
n_show = int(20000*scale)

# ---- Select images

x_show, y_show = fidle.utils.pick_dataset(x_data,y_data, n=n_show)

# ---- Get latent points

z_mean, z_var, z = vae.encoder.predict(x_show)

# ---- Show them

fig = plt.figure(figsize=(14, 10))
plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=30)
plt.colorbar()
fidle.scrawler.save_fig('07-Latent-space')
plt.show()
625/625 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step
Saved: ./run/K3VAE2/figs/07-Latent-space

7.4 - Generative latent space¶

In [17]:
if latent_dim>2:

    print('Sorry, this part only works when the latent space has dimension 2')

else:
    
    grid_size   = 18
    grid_scale  = 1

    # ---- Draw a ppf grid

    grid=[]
    for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size),scale=grid_scale):
        for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size),scale=grid_scale):
            grid.append( (x,y) )
    grid=np.array(grid)

    # ---- Draw latent points and grid

    fig = plt.figure(figsize=(10, 8))
    plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=20)
    plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)
    fidle.scrawler.save_fig('08-Latent-grid')
    plt.show()

    # ---- Plot grid corresponding images

    x_reconst = vae.decoder.predict([grid])
    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='09-Latent-morphing')
Sorry, this part only works when the latent space has dimension 2
In [18]:
fidle.end()

End time : 22/12/24 21:41:45
Duration : 00:04:51 162ms
This notebook ends here :-)
https://fidle.cnrs.fr


