
[K3VAE3] - Analysis of the VAE's latent space for the MNIST dataset¶

Visualization and analysis of the latent space of a VAE trained on the MNIST dataset

Objectives :¶

  • First data generation from latent space
  • Understanding of underlying principles
  • Model management

Here, we no longer consume data, we generate it! ;-)

What we're going to do :¶

  • Load a saved model
  • Reconstruct some images
  • Latent space visualization
  • Matrix of generated images

Step 1 - Init python stuff¶

1.1 - Init python¶

In [1]:
import os
os.environ['KERAS_BACKEND'] = 'torch'

import keras
from keras import layers

import numpy as np
import scipy.stats    # needed by the ppf grid in step 6

from modules.models    import VAE
from modules.datagen   import MNIST

import matplotlib
import matplotlib.pyplot as plt
from barviz import Simplex
from barviz import Collection

import sys
import fidle

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3VAE3')


FIDLE - Environment initialization

Version              : 2.3.2
Run id               : K3VAE3
Run dir              : ./run/K3VAE3
Datasets dir         : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 22/12/24 21:41:57
Hostname             : r3i5n3 (Linux)
Tensorflow log level : Info + Warning + Error  (=0)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3VAE3/figs (True)
keras                : 3.7.0
numpy                : 2.1.2
sklearn              : 1.5.2
yaml                 : 6.0.2
matplotlib           : 3.9.2
plotly               : 5.24.1
pandas               : 2.2.3
torch                : 2.5.0

1.2 - Parameters¶

In [2]:
scale      = 1
seed       = 123
models_dir = './run/K3VAE2'

Override parameters (batch mode) - Just forget this cell

In [3]:
fidle.override('scale', 'seed', 'models_dir')
** Overrided parameters : **
scale                : 1

Step 2 - Get data¶

In [4]:
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )
Seeded (123)
Dataset loaded.
Concatenated.
Shuffled.
rescaled (1).
Normalized.
Reshaped.
splited (1).
x_train shape is  :  (70000, 28, 28, 1)
x_test  shape is  :  (0, 28, 28, 1)
y_train shape is  :  (70000,)
y_test  shape is  :  (0,)
Blake2b digest is :  0c903710d4d28b01c174

Step 3 - Reload best model¶

In [5]:
vae=VAE()
vae.reload(f'{models_dir}/models/vae_model')
Fidle VAE is ready :-)  loss_weights=[1, 1]
Reloaded.
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 16 variables whereas the saved optimizer has 2 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 12 variables whereas the saved optimizer has 2 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))

Step 4 - Image reconstruction¶

In [6]:
# ---- Select few images

x_show = fidle.utils.pick_dataset(x_data, n=10)

# ---- Get latent points and reconstructed images

z_mean, z_var, z  = vae.encoder.predict(x_show, verbose=0)
x_reconst         = vae.decoder.predict(z,      verbose=0)

latent_dim        = z.shape[1]

# ---- Show it

labels=[ str(np.round(z[i],1)) for i in range(10) ]
fidle.utils.subtitle('Originals :')
fidle.scrawler.images(x_show,    None, indices='all', columns=10, x_size=2,y_size=2, save_as='01-original')
fidle.utils.subtitle('Reconstructed :')
fidle.scrawler.images(x_reconst, None, indices='all', columns=10, x_size=2,y_size=2, save_as='02-reconstruct')
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/backend/common/backend_utils.py:91: UserWarning: You might experience inconsistencies across backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1.
  warnings.warn(


Originals :

Saved: ./run/K3VAE3/figs/01-original
[Figure: the 10 original images]


Reconstructed :

Saved: ./run/K3VAE3/figs/02-reconstruct
[Figure: the 10 reconstructed images]
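
To put a number on what the two rows of images show, one option is a per-image mean squared error between originals and reconstructions. The sketch below is only an illustration: `reconstruction_error` is a hypothetical helper (not part of fidle or of the VAE module), and it runs on random stand-in arrays shaped like `x_show` and `x_reconst`, not on the notebook's data.

```python
import numpy as np

def reconstruction_error(x, x_reconst):
    """Mean squared error per image (hypothetical helper, not a fidle function)."""
    x = np.asarray(x, dtype=float)
    x_reconst = np.asarray(x_reconst, dtype=float)
    if x.shape != x_reconst.shape:
        raise ValueError("x and x_reconst must have the same shape")
    # Flatten each image to a vector, then average the squared differences
    diff = (x - x_reconst).reshape(len(x), -1)
    return np.mean(diff ** 2, axis=1)

# Stand-in data with the same layout as x_show / x_reconst: (n, 28, 28, 1) in [0, 1]
rng = np.random.default_rng(123)
x_demo = rng.random((10, 28, 28, 1))
x_noisy = np.clip(x_demo + rng.normal(0, 0.05, x_demo.shape), 0, 1)

errors = reconstruction_error(x_demo, x_noisy)
print(errors.shape)   # one value per image
```

In the notebook itself, the same call would presumably be `reconstruction_error(x_show, x_reconst)`.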

Step 5 - Visualizing the latent space¶

In [7]:
n_show = min( 20000, len(x_data) )

# ---- Select images

x_show, y_show = fidle.utils.pick_dataset(x_data,y_data, n=n_show)

# ---- Get latent points

z_mean, z_var, z = vae.encoder.predict(x_show, verbose=0)

5.1 - Classic 2D visualisation¶

In [8]:
fig = plt.figure(figsize=(14, 10))
plt.scatter(z[:, 2] , z[:, 4], c=y_show, cmap= 'tab10', alpha=0.5, s=30)
plt.colorbar()
fidle.scrawler.save_fig('03-Latent-space')
plt.show()
Saved: ./run/K3VAE3/figs/03-Latent-space
[Figure: scatter plot of latent dimensions 2 and 4, coloured by digit class]

5.2 - Simplex visualisation¶

In [9]:
if latent_dim<4:

    print('Sorry, This part can only work if the latent space is greater than 3')

else:

    # ---- Softmax rescale
    #
    zs = np.exp(z)/np.sum(np.exp(z),axis=1,keepdims=True)
    # zc  = zs * 1/np.max(zs)

    # ---- Create collection
    #
    c = Collection(zs, colors=y_show, labels=y_show)
    c.attrs.markers_colormap     = {'colorscale':'Rainbow','cmin':0,'cmax':latent_dim}
    c.attrs.markers_size         = 5
    c.attrs.markers_border_width = 0
    c.attrs.markers_opacity      = 0.8

    s = Simplex.build(latent_dim)
    s.attrs.width  = 1000
    s.attrs.height = 1000
    s.plot(c)
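
The softmax rescale in the cell above turns every latent vector into non-negative coordinates that sum to 1, which is what allows each point to be placed inside a simplex. A minimal sketch of that step follows. It differs from the cell in one respect: it subtracts the row maximum before calling `np.exp`, a common way to avoid overflow for large values. `softmax_rows` is a name chosen for this illustration.

```python
import numpy as np

def softmax_rows(z):
    """Row-wise softmax: each row becomes non-negative and sums to 1."""
    z = np.asarray(z, dtype=float)
    # Subtracting the row maximum leaves the result unchanged mathematically,
    # but keeps np.exp from overflowing on large inputs
    shifted = z - np.max(z, axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=1, keepdims=True)

# Two stand-in latent vectors of dimension 4; the second would overflow
# with a naive np.exp(z) / np.sum(np.exp(z))
z_demo = np.array([[0.0, 1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0, 1000.0]])
zs_demo = softmax_rows(z_demo)
print(zs_demo.sum(axis=1))
```

Note that softmax is not a distance-preserving projection: it only keeps the relative ordering of the coordinates within each vector, so the simplex view is a qualitative picture of the latent space.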

Step 6 - Generate from latent space (latent_dim==2)¶

In [10]:
if latent_dim>2:

    print('Sorry, This part can only work if the latent space is of dimension 2')

else:

    grid_size   = 14
    grid_scale  = 1.

    # ---- Draw a ppf grid

    grid=[]
    for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size),scale=grid_scale):
        for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size),scale=grid_scale):
            grid.append( (x,y) )
    grid=np.array(grid)

    # ---- Draw latent points and grid

    fig = plt.figure(figsize=(12, 10))
    plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=20)
    plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)
    fidle.scrawler.save_fig('04-Latent-grid')
    plt.show()

    # ---- Plot grid corresponding images

    x_reconst = vae.decoder.predict([grid])
    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='05-Latent-morphing')
Sorry, This part can only work if the latent space is of dimension 2
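
Since the saved model has more than two latent dimensions, the cell above does not run here. The grid it would build can still be examined on its own: the coordinates are quantiles of a normal distribution, so the points are dense near the origin, where most encoded images lie, and sparse in the tails. The sketch below is an illustration under assumptions: it uses the standard library's `statistics.NormalDist.inv_cdf` as a stand-in for `scipy.stats.norm.ppf`, and `ppf_grid` is a hypothetical helper name.

```python
from statistics import NormalDist
import numpy as np

def ppf_grid(grid_size=14, grid_scale=1.0, p_min=0.01, p_max=0.99):
    """Build a (grid_size**2, 2) array of (x, y) points on normal quantiles.

    Hypothetical helper; NormalDist.inv_cdf stands in for scipy.stats.norm.ppf.
    """
    normal = NormalDist(mu=0.0, sigma=grid_scale)
    probs = np.linspace(p_min, p_max, grid_size)
    coords = [normal.inv_cdf(float(p)) for p in probs]
    # Rows run from the largest y to the smallest, columns from left to right,
    # so the flat order matches an image matrix read row by row
    return np.array([(x, y) for y in reversed(coords) for x in coords])

grid_demo = ppf_grid()
print(grid_demo.shape)
print(grid_demo[0], grid_demo[-1])   # top-left and bottom-right corners
```

With a 2-dimensional latent space, an array like `grid_demo` is what the decoder would receive, one generated image per grid point.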
In [11]:
fidle.end()

End time : 22/12/24 21:42:13
Duration : 00:00:16 866ms
This notebook ends here :-)
https://fidle.cnrs.fr


