[K3VAE3] - Analysis of the VAE's latent space of the MNIST dataset¶
Visualization and analysis of the VAE's latent space for the MNIST dataset.
Objectives :¶
- First data generation from latent space
- Understanding of underlying principles
- Model management
Here, we no longer consume data, we generate it! ;-)
What we're going to do :¶
- Load a saved model
- Reconstruct some images
- Latent space visualization
- Matrix of generated images
Step 1 - Init python stuff¶
1.1 - Init python¶
In [1]:
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
from keras import layers
import numpy as np
import scipy.stats
from modules.models import VAE
from modules.datagen import MNIST
import matplotlib
import matplotlib.pyplot as plt
from barviz import Simplex
from barviz import Collection
import sys
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3VAE3')
FIDLE - Environment initialization
Version              : 2.3.0
Run id               : K3VAE3
Run dir              : ./run/K3VAE3
Datasets dir         : /gpfswork/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 03/03/24 21:29:39
Hostname             : r6i0n6 (Linux)
Tensorflow log level : Warning + Error (=1)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3VAE3/figs (True)
keras                : 3.0.4
numpy                : 1.24.4
sklearn              : 1.3.2
yaml                 : 6.0.1
matplotlib           : 3.8.2
plotly               : 5.18.0
pandas               : 2.1.3
torch                : 2.1.1
1.2 - Parameters¶
In [2]:
scale = 1
seed = 123
models_dir = './run/K3VAE2'
Override parameters (batch mode) - Just forget this cell
In [3]:
fidle.override('scale', 'seed', 'models_dir')
** Overrided parameters : ** scale : 1
Step 2 - Get data¶
In [4]:
x_data, y_data, _,_ = MNIST.get_data(seed=seed, scale=scale, train_prop=1 )
Seeded (123)
Dataset loaded. Concatenated. Shuffled. rescaled (1). Normalized. Reshaped. splited (1).
x_train shape is : (70000, 28, 28, 1)
x_test shape is  : (0, 28, 28, 1)
y_train shape is : (70000,)
y_test shape is  : (0,)
Blake2b digest is : 0c903710d4d28b01c174
Step 3 - Reload best model¶
In [5]:
vae=VAE()
vae.reload(f'{models_dir}/models/vae_model')
Fidle VAE is ready :-) loss_weights=[1, 1]
Reloaded.
/gpfswork/rech/mlh/uja62cb/local/fidle-k3/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:394: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 16 variables whereas the saved optimizer has 2 variables.
  trackable.load_own_variables(weights_store.get(inner_path))
/gpfswork/rech/mlh/uja62cb/local/fidle-k3/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:394: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 12 variables whereas the saved optimizer has 2 variables.
  trackable.load_own_variables(weights_store.get(inner_path))
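As a quick sanity check after reloading, you can print the summaries of the two sub-models. This is a minimal sketch, assuming vae.encoder and vae.decoder are ordinary Keras models (which the predict() calls below suggest); it is not part of the original notebook:
# ---- Optional sanity check on the reloaded sub-models (assumes Keras models)
vae.encoder.summary()
vae.decoder.summary()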
Step 4 - Image reconstruction¶
In [6]:
# ---- Select few images
x_show = fidle.utils.pick_dataset(x_data, n=10)
# ---- Get latent points and reconstructed images
z_mean, z_var, z = vae.encoder.predict(x_show, verbose=0)
x_reconst = vae.decoder.predict(z, verbose=0)
latent_dim = z.shape[1]
# ---- Show it
labels=[ str(np.round(z[i],1)) for i in range(10) ]
fidle.utils.subtitle('Originals :')
fidle.scrawler.images(x_show, None, indices='all', columns=10, x_size=2,y_size=2, save_as='01-original')
fidle.utils.subtitle('Reconstructed :')
fidle.scrawler.images(x_reconst, None, indices='all', columns=10, x_size=2,y_size=2, save_as='02-reconstruct')
/gpfswork/rech/mlh/uja62cb/local/fidle-k3/lib/python3.11/site-packages/keras/src/backend/common/backend_utils.py:88: UserWarning: You might experience inconsistencies accross backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1.
  warnings.warn(
Originals :
Saved: ./run/K3VAE3/figs/01-original
Reconstructed :
Saved: ./run/K3VAE3/figs/02-reconstruct
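To go beyond a visual comparison, the reconstruction quality can also be quantified. Here is a minimal sketch, assuming x_show and x_reconst are the arrays computed above (pixel values in [0,1]); the mean squared error metric is only an illustration, not part of the original notebook:
# ---- Per-image mean squared error between originals and reconstructions (illustrative)
mse = np.mean((x_show - x_reconst)**2, axis=(1,2,3))
print('MSE per image :', np.round(mse, 4))
print('Mean MSE      :', mse.mean())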
Step 5 - Visualizing the latent space¶
In [7]:
n_show = 20000
# ---- Select images
x_show, y_show = fidle.utils.pick_dataset(x_data,y_data, n=n_show)
# ---- Get latent points
z_mean, z_var, z = vae.encoder.predict(x_show, verbose=0)
5.1 - Classic 2D visualisation¶
In [8]:
# ---- Plot two latent dimensions (here 2 and 4), colored by digit class
fig = plt.figure(figsize=(14, 10))
plt.scatter(z[:, 2], z[:, 4], c=y_show, cmap='tab10', alpha=0.5, s=30)
plt.colorbar()
fidle.scrawler.save_fig('03-Latent-space')
plt.show()
Saved: ./run/K3VAE3/figs/03-Latent-space
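Note that the scatter above only shows two arbitrarily chosen latent dimensions (here 2 and 4). To get a single 2D view that takes all dimensions into account, a PCA projection is a common alternative. The sketch below uses scikit-learn (present in the environment listed above) and is only an illustration, not part of the original notebook:
# ---- Optional: project the whole latent space to 2D with PCA (illustrative)
from sklearn.decomposition import PCA

z_2d = PCA(n_components=2).fit_transform(z)

plt.figure(figsize=(14, 10))
plt.scatter(z_2d[:, 0], z_2d[:, 1], c=y_show, cmap='tab10', alpha=0.5, s=30)
plt.colorbar()
plt.show()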
5.2 - Simplex visualisation¶
In [9]:
if latent_dim<4:

    print('Sorry, this part only works if the latent space dimension is greater than 3')

else:

    # ---- Softmax rescale
    #
    zs = np.exp(z)/np.sum(np.exp(z),axis=1,keepdims=True)
    # zc = zs * 1/np.max(zs)

    # ---- Create collection
    #
    c = Collection(zs, colors=y_show, labels=y_show)
    c.attrs.markers_colormap     = {'colorscale':'Rainbow','cmin':0,'cmax':latent_dim}
    c.attrs.markers_size         = 5
    c.attrs.markers_border_width = 0
    c.attrs.markers_opacity      = 0.8

    s = Simplex.build(latent_dim)
    s.attrs.width  = 1000
    s.attrs.height = 1000
    s.plot(c)
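The softmax rescale used above turns each latent vector into a point of the probability simplex (non-negative components summing to 1), which is presumably what barviz's Simplex plot expects. In formula form, for a latent vector $z_i$ of dimension $d$:
$$\tilde z_{ij} = \frac{e^{z_{ij}}}{\sum_{k=1}^{d} e^{z_{ik}}}$$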
Step 6 - Generate from latent space (latent_dim==2)¶
In [10]:
if latent_dim!=2:

    print('Sorry, this part only works if the latent space is of dimension 2')

else:

    grid_size  = 14
    grid_scale = 1.

    # ---- Draw a ppf grid
    grid=[]
    for y in scipy.stats.norm.ppf(np.linspace(0.99, 0.01, grid_size),scale=grid_scale):
        for x in scipy.stats.norm.ppf(np.linspace(0.01, 0.99, grid_size),scale=grid_scale):
            grid.append( (x,y) )
    grid=np.array(grid)

    # ---- Draw latent points and grid
    fig = plt.figure(figsize=(12, 10))
    plt.scatter(z[:, 0] , z[:, 1], c=y_show, cmap= 'tab10', alpha=0.5, s=20)
    plt.scatter(grid[:, 0] , grid[:, 1], c = 'black', s=60, linewidth=2, marker='+', alpha=1)
    fidle.scrawler.save_fig('04-Latent-grid')
    plt.show()

    # ---- Plot grid corresponding images
    x_reconst = vae.decoder.predict([grid])
    fidle.scrawler.images(x_reconst, indices='all', columns=grid_size, x_size=0.5,y_size=0.5, y_padding=0,spines_alpha=0.1, save_as='05-Latent-morphing')
Sorry, this part only works if the latent space is of dimension 2
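Since the latent space used here has more than 2 dimensions, the morphing grid above is skipped. Generation itself works in any dimension: you can sample latent vectors from the standard normal prior and decode them. A minimal sketch, not part of the original notebook (save_as=None is assumed to skip figure saving):
# ---- Generate new digits by sampling the prior N(0,I) in latent space (illustrative)
n_gen    = 10
z_random = np.random.normal(size=(n_gen, latent_dim))
x_gen    = vae.decoder.predict(z_random, verbose=0)
fidle.scrawler.images(x_gen, None, indices='all', columns=10, x_size=2, y_size=2, save_as=None)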
In [11]:
fidle.end()
End time : 03/03/24 21:30:01
Duration : 00:00:22 594ms
This notebook ends here :-)
https://fidle.cnrs.fr