[K3VAE2] - VAE, using a custom model class (MNIST dataset)¶
Construction and training of a VAE, using a model subclass, with a low-dimensional latent space.
Objectives :¶
- Understanding and implementing a variational autoencoder neural network (VAE)
- Understanding an even more advanced programming model, based on a custom model
Since the computational needs are significant, it is preferable to start with a very simple dataset such as MNIST.
...and to use MNIST at a small scale if you don't have a GPU ;-)
What we're going to do :¶
- Defining a VAE model
- Build the model
- Train it
- Take a look at the training process
Acknowledgements :¶
Thanks to François Chollet, whose work this example is based on (and who created Keras !!).
See : https://keras.io/examples/generative/vae
Step 1 - Init python stuff¶
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
from keras import layers
import numpy as np
from modules.models import VAE
from modules.layers import SamplingLayer
from modules.callbacks import ImagesCallback
from modules.datagen import MNIST
import matplotlib.pyplot as plt
import scipy.stats
import sys
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3VAE2')
VAE.about()
FIDLE - Environment initialization
Version : 2.3.2 Run id : K3VAE2 Run dir : ./run/K3VAE2 Datasets dir : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle Start time : 22/12/24 21:36:54 Hostname : r3i5n3 (Linux) Tensorflow log level : Info + Warning + Error (=0) Update keras cache : False Update torch cache : False Save figs : ./run/K3VAE2/figs (True) keras : 3.7.0 numpy : 2.1.2 sklearn : 1.5.2 yaml : 6.0.2 skimage : 0.24.0 matplotlib : 3.9.2 pandas : 2.2.3 torch : 2.5.0
FIDLE 2024 - VAE
Version : 2.0 Keras version : 3.7.0
Step 2 - Parameters¶
scale
: with scale=1, training takes about 1'30s on a V100 GPU ...and more than 20' on a CPU !
latent_dim
: 2 dimensions is small, but useful for drawing !
fit_verbosity
: Verbosity of the training progress bar: 0=silent, 1=progress bar, 2=one line per epoch
loss_weights
: Our loss function is the weighted sum of two losses:
- r_loss, which measures the reconstruction loss,
- kl_loss, which measures the dispersion, i.e. the Kullback-Leibler divergence between the encoded latent distribution and a standard normal distribution.
The weights are defined by: loss_weights=[k1,k2]
where : total_loss = k1*r_loss + k2*kl_loss
In practice, a value of [1,.06] gives good results here.
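To make the weighted sum concrete, here is a minimal NumPy sketch of how such a loss can be computed for one batch. It is an assumption, not the code of VAE.py: r_loss is taken to be the binary cross-entropy summed over all pixels of an image, and kl_loss the closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I), both averaged over the batch. The function name `vae_losses` is hypothetical.

```python
import numpy as np

def vae_losses(x, x_reconst, z_mean, z_log_var, loss_weights=(1.0, 0.06), eps=1e-7):
    """Return (total_loss, r_loss, kl_loss), each averaged over the batch.

    Assumptions (not taken from VAE.py): r_loss is the binary cross-entropy
    summed over all pixels of an image; kl_loss is the closed-form KL
    divergence between N(z_mean, exp(z_log_var)) and N(0, I).
    """
    k1, k2 = loss_weights
    x_reconst = np.clip(x_reconst, eps, 1.0 - eps)      # avoid log(0)
    pixel_axes = tuple(range(1, x.ndim))                # every axis except the batch axis
    bce = -(x * np.log(x_reconst) + (1.0 - x) * np.log(1.0 - x_reconst))
    r_loss = np.mean(np.sum(bce, axis=pixel_axes))
    kl = -0.5 * np.sum(1.0 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=1)
    kl_loss = np.mean(kl)
    total_loss = k1 * r_loss + k2 * kl_loss
    return total_loss, r_loss, kl_loss

# Usage with dummy data: 2 images of 28x28x1 and a 6-dimensional latent space
x = np.zeros((2, 28, 28, 1))
x_reconst = np.full_like(x, 0.5)
z_mean, z_log_var = np.ones((2, 6)), np.zeros((2, 6))
total_loss, r_loss, kl_loss = vae_losses(x, x_reconst, z_mean, z_log_var)
print(total_loss, r_loss, kl_loss)
```

Because r_loss is summed over 784 pixels while kl_loss is summed over only latent_dim values, the two terms have very different magnitudes, which is one reason a small weight such as k2=0.06 can be a sensible choice.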
latent_dim = 6
loss_weights = [1,.06]
scale = .2
seed = 123
batch_size = 64
epochs = 4
fit_verbosity = 1
Override parameters (batch mode) - Just forget this cell
fidle.override('latent_dim', 'loss_weights', 'scale', 'seed', 'batch_size', 'epochs', 'fit_verbosity')
** Overrided parameters : ** scale : 1 epochs : 20 fit_verbosity : 2
Step 3 - Prepare data¶
MNIST.get_data()
return : x_train,y_train, x_test,y_test
,
but we only need x_train for our training.
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )
fidle.scrawler.images(x_data[:20], None, indices='all', columns=10, x_size=1,y_size=1,y_padding=0, save_as='01-original')
Seeded (123)
Dataset loaded. Concatenated. Shuffled. rescaled (1). Normalized. Reshaped. splited (1).
x_train shape is : (70000, 28, 28, 1) x_test shape is : (0, 28, 28, 1) y_train shape is : (70000,) y_test shape is : (0,) Blake2b digest is : 0c903710d4d28b01c174
Step 4 - Build model¶
In this example, we will use a custom model. For this, we will use:
- SamplingLayer, which generates a vector z from the parameters z_mean and z_log_var - See: SamplingLayer.py
- VAE, a custom model with a specific train_step - See: VAE.py
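SamplingLayer.py is not reproduced here. The NumPy sketch below shows what such a layer typically computes, namely the reparameterization trick, assuming that z_log_var holds log(sigma^2). The real layer works on Keras tensors; `sample_z` is a hypothetical name.

```python
import numpy as np

def sample_z(z_mean, z_log_var, rng):
    """Reparameterization trick: z = z_mean + sigma * epsilon, with epsilon ~ N(0, I).

    Assumes z_log_var holds log(sigma^2), hence sigma = exp(0.5 * z_log_var).
    """
    epsilon = rng.standard_normal(size=z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * epsilon

rng = np.random.default_rng(123)
z_mean = np.zeros((10000, 6))
z_log_var = np.full((10000, 6), np.log(4.0))    # variance 4, i.e. sigma = 2
z = sample_z(z_mean, z_log_var, rng)
print(z.shape, z.std())                          # std should be close to 2
```

The randomness lives entirely in epsilon, so z is a differentiable function of z_mean and z_log_var and gradients can flow back into the encoder during training.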
Encoder¶
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, strides=1, padding="same", activation="relu")(inputs)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2D(64, 3, strides=1, padding="same", activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = SamplingLayer()([z_mean, z_log_var])
encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
encoder.compile()
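The following small calculation checks the feature-map sizes produced by the encoder. It relies on the rule that, with padding="same", a Keras Conv2D layer outputs ceil(size / stride) along each spatial dimension; `conv_same_out` is a hypothetical helper.

```python
import math

def conv_same_out(size, stride):
    # With padding="same", a Keras Conv2D layer outputs ceil(size / stride) per spatial dimension
    return math.ceil(size / stride)

size, shapes = 28, []
for filters, stride in [(32, 1), (64, 2), (64, 2), (64, 1)]:   # the four Conv2D layers above
    size = conv_same_out(size, stride)
    shapes.append((size, size, filters))
flat = size * size * shapes[-1][2]
print(shapes)   # feature-map shapes after each Conv2D
print(flat)     # 3136 = 7*7*64, the length of the Flatten output
```

The two stride-2 layers take the image from 28x28 down to 7x7, which is why the decoder below starts from a Dense layer of size 7*7*64.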
Decoder¶
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, strides=1, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
outputs = layers.Conv2DTranspose(1, 3, padding="same", activation="sigmoid")(x)
decoder = keras.Model(inputs, outputs, name="decoder")
decoder.compile()
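The decoder mirrors the encoder. With padding="same", Keras Conv2DTranspose produces input_size × stride, so the spatial size goes 7 → 7 → 14 → 28 → 28 and the output matches the 28x28x1 input. The sketch below uses the general transposed-convolution formula from the PyTorch ConvTranspose2d documentation (dilation=1). The assumption here is that, on the torch backend, the stride-2 layers are realized with output_padding=1, which is presumably what the UserWarning printed during training refers to.

```python
def conv_transpose_out(size, stride, kernel=3, pad=1, output_padding=0):
    # Transposed-convolution output size (formula from the PyTorch ConvTranspose2d docs, dilation=1)
    return (size - 1) * stride - 2 * pad + kernel + output_padding

size, sizes = 7, []           # spatial size after Dense(7*7*64) + Reshape((7, 7, 64))
for stride in [1, 2, 2, 1]:   # the four Conv2DTranspose layers above
    op = stride - 1           # assumption: output_padding=1 for stride 2, 0 for stride 1
    size = conv_transpose_out(size, stride, output_padding=op)
    sizes.append(size)
print(sizes)
```

Without output_padding, a stride-2 layer with kernel 3 and padding 1 would turn 7 into 13 rather than 14, which is why the extra padding is needed.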
vae = VAE(encoder, decoder, loss_weights)
vae.compile(optimizer='adam')
Fidle VAE is ready :-) loss_weights=[1, 0.06]
Step 5 - Train¶
5.1 - Using two nice custom callbacks :-)¶
Two custom callbacks are available (only ImagesCallback is used in this notebook):
- ImagesCallback: saves images during training - See ImagesCallback.py
- BestModelCallback: saves the best model - See BestModelCallback.py
callback_images = ImagesCallback(x=x_data, z_dim=latent_dim, nb_images=5, from_z=True, from_random=True, run_dir=run_dir)
callbacks_list = [callback_images]
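ImagesCallback.py is not shown here. As a rough idea of what a callback with from_z=True and from_random=True might do, here is a hypothetical, simplified sketch (not the Fidle implementation). A real Keras callback would subclass keras.callbacks.Callback and reach the model through self.model; here encode and decode are plain callables so the sketch runs without Keras.

```python
import numpy as np

class ImagesCallbackSketch:
    """Hypothetical, simplified stand-in for ImagesCallback (NOT the Fidle code)."""

    def __init__(self, x, z_dim, encode, decode, nb_images=5, seed=0):
        self.x = x[:nb_images]                          # fixed images to reconstruct
        self.encode, self.decode = encode, decode
        rng = np.random.default_rng(seed)
        self.z_random = rng.standard_normal((nb_images, z_dim))   # fixed random latent points
        self.images_z, self.images_r = {}, {}

    def on_epoch_end(self, epoch, logs=None):
        # The real encoder returns [z_mean, z_log_var, z]; here encode returns z directly
        z = self.encode(self.x)
        self.images_z[epoch] = self.decode(z)              # image -> z -> reconstructed image
        self.images_r[epoch] = self.decode(self.z_random)  # random z -> generated image

    def get_images(self, epochs):
        # Stack the images saved at the requested epochs, one row of nb_images per epoch
        images_z = np.concatenate([self.images_z[e] for e in epochs])
        images_r = np.concatenate([self.images_r[e] for e in epochs])
        return images_z, images_r

# Usage with dummy encoder/decoder callables
x = np.random.default_rng(1).random((10, 28, 28, 1))
encode = lambda imgs: imgs.reshape(len(imgs), -1)[:, :6]    # dummy "encoder"
decode = lambda z: np.zeros((len(z), 28, 28, 1))            # dummy "decoder"
cb = ImagesCallbackSketch(x, z_dim=6, encode=encode, decode=decode)
for epoch in range(4):
    cb.on_epoch_end(epoch)
images_z, images_r = cb.get_images(range(0, 4, 2))
```

Keeping both the input images and the random latent points fixed across epochs makes the saved rows directly comparable, so the evolution of the decoder is visible from one epoch to the next.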
5.2 - Let's train !¶
With scale=1, training takes about 1'15 on a GPU (V100 at IDRIS) ...or 20' on a CPU
chrono=fidle.Chrono()
chrono.start()
history = vae.fit(x_data, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list, verbose=fit_verbosity)
chrono.show()
Epoch 1/20
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/backend/common/backend_utils.py:91: UserWarning: You might experience inconsistencies across backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1. warnings.warn(
1094/1094 - 13s - 12ms/step - loss: 10896.0166
Epoch 2/20
1094/1094 - 13s - 12ms/step - loss: 8783.6133
Epoch 3/20
1094/1094 - 13s - 12ms/step - loss: 8510.3066
Epoch 4/20
1094/1094 - 13s - 12ms/step - loss: 8344.8555
Epoch 5/20
1094/1094 - 13s - 12ms/step - loss: 8238.4219
Epoch 6/20
1094/1094 - 13s - 12ms/step - loss: 8166.6108
Epoch 7/20
1094/1094 - 13s - 12ms/step - loss: 8111.5840
Epoch 8/20
1094/1094 - 13s - 12ms/step - loss: 8055.1963
Epoch 9/20
1094/1094 - 13s - 12ms/step - loss: 8022.8188
Epoch 10/20
1094/1094 - 13s - 12ms/step - loss: 7979.8218
Epoch 11/20
1094/1094 - 13s - 12ms/step - loss: 7951.6973
Epoch 12/20
1094/1094 - 13s - 12ms/step - loss: 7922.7065
Epoch 13/20
1094/1094 - 13s - 12ms/step - loss: 7900.1572
Epoch 14/20
1094/1094 - 13s - 12ms/step - loss: 7871.8921
Epoch 15/20
1094/1094 - 13s - 12ms/step - loss: 7854.2183
Epoch 16/20
1094/1094 - 13s - 12ms/step - loss: 7835.1177
Epoch 17/20
1094/1094 - 13s - 12ms/step - loss: 7816.2168
Epoch 18/20
1094/1094 - 13s - 12ms/step - loss: 7793.7368
Epoch 19/20
1094/1094 - 13s - 12ms/step - loss: 7780.6123
Epoch 20/20
1094/1094 - 13s - 12ms/step - loss: 7768.8008
Duration : 256.22 seconds
fidle.scrawler.history(history, plot={"Loss":['loss']}, save_as='history')
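The loss curve flattens out in the later epochs. One simple way to quantify this is to compute the relative drop of the loss from one epoch to the next. The values below are copied from the training log of this run (scale=1, epochs=20); in the notebook the same list is available as history.history['loss'].

```python
# Per-epoch training loss copied from the log above (scale=1, epochs=20).
# In the notebook, the same list is available as history.history['loss'].
losses = [10896.0166, 8783.6133, 8510.3066, 8344.8555, 8238.4219,
          8166.6108, 8111.5840, 8055.1963, 8022.8188, 7979.8218,
          7951.6973, 7922.7065, 7900.1572, 7871.8921, 7854.2183,
          7835.1177, 7816.2168, 7793.7368, 7780.6123, 7768.8008]

# Relative decrease of the loss from one epoch to the next
rel_drop = [(prev - cur) / prev for prev, cur in zip(losses, losses[1:])]
for epoch, r in enumerate(rel_drop, start=2):
    print(f"epoch {epoch:2d}: -{100 * r:.2f}%")
```

The first epoch-to-epoch step removes almost a fifth of the loss, while the last steps each improve it by well under 1%, so additional epochs would likely bring only modest gains.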
6.2 - Reconstruction during training¶
At the end of each epoch, our callback saved some reconstructed images.
Where :
Original image -> encoder -> z -> decoder -> Reconstructed image
images_z, images_r = callback_images.get_images( range(0,epochs,2) )
fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as='02-original')
fidle.utils.subtitle('Encoded/decoded images')
fidle.scrawler.images(images_z, None, indices='all', columns=5, x_size=2,y_size=2, save_as='03-reconstruct')
fidle.utils.subtitle('Original images :')
fidle.scrawler.images(x_data[:5], None, indices='all', columns=5, x_size=2,y_size=2, save_as=None)
Original images :
Encoded/decoded images
Original images :
6.3 - Generation (latent -> decoder) during training¶
fidle.utils.subtitle('Generated images from latent space')
fidle.scrawler.images(images_r, None, indices='all', columns=5, x_size=2,y_size=2, save_as='04-encoded')
Generated images from latent space
6.4 - Save model¶
os.makedirs(f'{run_dir}/models', exist_ok=True)
vae.save(f'{run_dir}/models/vae_model.keras')
Step 7 - Model evaluation¶
7.1 - Reload model¶
vae=VAE()
vae.reload(f'{run_dir}/models/vae_model.keras')
Fidle VAE is ready :-) loss_weights=[1, 1] Reloaded.
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 16 variables whereas the saved optimizer has 2 variables. saveable.load_own_variables(weights_store.get(inner_path)) /lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 12 variables whereas the saved optimizer has 2 variables. saveable.load_own_variables(weights_store.get(inner_path))
7.2 - Image reconstruction¶
# ---- Select a few images
x_show = fidle.utils.pick_dataset(x_data, n=10)
# ---- Get latent points and reconstructed images
z_mean, z_var, z = vae.encoder.predict(x_show)
x_reconst = vae.decoder.predict(z)
# ---- Show it
labels=[ str(np.round(z[i],1)) for i in range(10) ]
fidle.scrawler.images(x_show, None, indices='all', columns=10, x_size=2,y_size=2, save_as='05-original')
fidle.scrawler.images(x_reconst, None, indices='all', columns=10, x_size=2,y_size=2, save_as='06-reconstruct')
7.3 - Visualization of the latent space¶
n_show = int(20000*scale)
# ---- Select images
x_show, y_show = fidle.utils.pick_dataset(x_data,y_data, n=n_show)
# ---- Get latent points
z_mean, z_var, z = vae.encoder.predict(x_show)
# ---- Show them
fig = plt.figure(figsize=(14, 10))
plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=30)
plt.colorbar()
fidle.scrawler.save_fig('07-Latent-space')
plt.show()
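Since latent_dim = 6 in this run, the scatter plot above only shows the first two of the six latent coordinates, i.e. one particular slice of the latent space. One alternative, not part of the original notebook, is to project the latent points onto their first two principal components (PCA, computed here with NumPy's SVD). `pca_2d` is a hypothetical helper.

```python
import numpy as np

def pca_2d(z):
    """Project latent vectors of shape (n, d) onto their first two principal components.

    Returns the projected points (n, 2) and the fraction of variance each component explains.
    """
    zc = z - z.mean(axis=0)                              # center every latent dimension
    _, s, vt = np.linalg.svd(zc, full_matrices=False)    # rows of vt are principal directions
    explained = s**2 / np.sum(s**2)
    return zc @ vt[:2].T, explained[:2]

# In the notebook one could call:  z2, ratio = pca_2d(z)
# and then plt.scatter(z2[:, 0], z2[:, 1], c=y_show, cmap='tab10', alpha=0.5, s=30)
rng = np.random.default_rng(0)
z_demo = rng.standard_normal((500, 6)) * np.array([1.0, 1.0, 5.0, 1.0, 3.0, 1.0])
z2, ratio = pca_2d(z_demo)
print(z2.shape, ratio)
```

PCA is linear and only keeps the two directions of largest variance; the printed variance ratios indicate how much of the latent structure the 2D view actually retains.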
625/625 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step
7.4 - Generative latent space¶
if latent_dim>2:

    print('Sorry, This part can only work if the latent space is of dimension 2')

else:

    grid_size   = 18
    grid_scale  = 1

    # ---- Draw a ppf grid

    grid=[]
    for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size),scale=grid_scale):
        for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size),scale=grid_scale):
            grid.append( (x,y) )
    grid=np.array(grid)

    # ---- Draw latent points and grid

    fig = plt.figure(figsize=(10, 8))
    plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=20)
    plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)
    fidle.scrawler.save_fig('08-Latent-grid')
    plt.show()

    # ---- Plot grid corresponding images

    x_reconst = vae.decoder.predict([grid])
    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='09-Latent-morphing')
Sorry, This part can only work if the latent space is of dimension 2
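This part was skipped in the run above because latent_dim = 6, but the grid itself does not depend on the model. The sketch below rebuilds it with the standard library's `statistics.NormalDist.inv_cdf` (Python 3.8+) in place of `scipy.stats.norm.ppf`; `ppf_grid` is a hypothetical helper. Mapping evenly spaced probabilities through the inverse CDF places more grid points near 0, where most latent codes lie under the N(0, 1) prior.

```python
from statistics import NormalDist

def ppf_grid(grid_size=18, grid_scale=1.0, p_min=0.01, p_max=0.99):
    """Build the latent-space grid of section 7.4 without SciPy.

    Evenly spaced probabilities are mapped through the inverse CDF (ppf) of
    N(0, grid_scale**2), so points are denser near 0.
    """
    dist = NormalDist(mu=0.0, sigma=grid_scale)
    step = (p_max - p_min) / (grid_size - 1)
    xs = [dist.inv_cdf(p_min + i * step) for i in range(grid_size)]
    ys = xs[::-1]                       # top row first, like np.linspace(0.99, 0.01, ...)
    return [(x, y) for y in ys for x in xs]

grid = ppf_grid()
print(len(grid), grid[0])
```

The grid spans roughly ±2.33 (the 1% and 99% quantiles of N(0, 1)), so it covers the region where the decoder has actually seen latent codes during training.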
fidle.end()
End time : 22/12/24 21:41:45
Duration : 00:04:51 162ms
This notebook ends here :-)
https://fidle.cnrs.fr