[K3AE2] - Building and training an AE denoiser model
Episode 1 : Construction of a denoising autoencoder and training it with a noisy MNIST dataset.
Objectives :
- Understanding and implementing a denoising autoencoder (AE) neural network
- A first overview of the Keras functional syntax
Since the computational requirements are significant, it is preferable to use a very simple dataset such as MNIST.
The use of a GPU is often indispensable.
What we're going to do :
- Defining an AE model
- Build the model
- Train it
- Follow the learning process with Tensorboard
Data Terminology :
- clean_train, clean_test for noiseless images
- noisy_train, noisy_test for noisy images
- denoised_test for denoised images at the output of the model
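The noisy dataset itself was prepared in a previous episode; purely as an illustration, such (noisy, clean) pairs can be produced by adding clipped Gaussian noise to clean images. The `add_noise` helper and its `noise_level` parameter below are hypothetical names, not part of the Fidle modules:

```python
import numpy as np

def add_noise(images, noise_level=0.5, seed=None):
    """Corrupt images with pixel values in [0,1] using additive Gaussian
    noise, then clip back to [0,1]. Illustrative sketch only; the actual
    noisy MNIST dataset is built in a previous episode of this course."""
    rng = np.random.default_rng(seed)
    noisy = images + noise_level * rng.standard_normal(images.shape)
    return np.clip(noisy, 0.0, 1.0)

# Example with a placeholder batch of 4 "images"
clean = np.random.rand(4, 28, 28, 1)
noisy = add_noise(clean, noise_level=0.5, seed=0)
```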
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
import numpy as np
from skimage import io
import random
from keras import layers
from keras.callbacks import ModelCheckpoint, TensorBoard
from importlib import reload
from modules.MNIST import MNIST
from modules.ImagesCallback import ImagesCallback
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3AE2')
FIDLE - Environment initialization
Version              : 2.3.0
Run id               : K3AE2
Run dir              : ./run/K3AE2
Datasets dir         : /gpfswork/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 03/03/24 21:08:52
Hostname             : r6i1n1 (Linux)
Tensorflow log level : Warning + Error (=1)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3AE2/figs (True)
keras                : 3.0.4
numpy                : 1.24.4
sklearn              : 1.3.2
yaml                 : 6.0.1
skimage              : 0.22.0
matplotlib           : 3.8.2
pandas               : 2.1.3
torch                : 2.1.1
1.2 - Parameters
- prepared_dataset : Filename of the prepared dataset (needs about 400 MB, but can be in ./data)
- dataset_seed : Random seed for shuffling the dataset
- scale : Proportion of the dataset to use (1. for 100%)
- latent_dim : Dimension of the latent space
- train_prop : Proportion used for training (the rest being for the test)
- batch_size : Batch size
- epochs : Number of epochs for training
- fit_verbosity : Verbosity during training : 0 = silent, 1 = progress bar, 2 = one line per epoch
Note : scale=.2, epochs=20 => about 3'30s on a laptop
prepared_dataset = './data/mnist-noisy.h5'
dataset_seed = 123
scale = .1
latent_dim = 10
train_prop = .8
batch_size = 128
epochs = 20
fit_verbosity = 1
Override parameters (batch mode) - You can ignore this cell
fidle.override('prepared_dataset', 'dataset_seed', 'scale', 'latent_dim')
fidle.override('train_prop', 'batch_size', 'epochs', 'fit_verbosity')
** Overrided parameters : **
scale                : 1
latent_dim           : 10
** Overrided parameters : **
epochs               : 20
Step 2 - Retrieve dataset
With our MNIST class, in one call, we can reload, rescale, shuffle and split our previously saved dataset :-)
clean_train,clean_test, noisy_train,noisy_test, _,_ = MNIST.reload_prepared_dataset(scale = scale,
train_prop = train_prop,
seed = dataset_seed,
shuffle = True,
filename=prepared_dataset )
Loaded. rescaled (1). Seeded (123) Shuffled. splited (0.8).
clean_train shape is : (56000, 28, 28, 1)
clean_test  shape is : (14000, 28, 28, 1)
noisy_train shape is : (56000, 28, 28, 1)
noisy_test  shape is : (14000, 28, 28, 1)
class_train shape is : (56000,)
class_test  shape is : (14000,)
Blake2b digest is : de2af55afacf9fb3ee93
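Before training, it can be worth checking that the reloaded pairs are aligned and that pixel values are scaled to [0,1]. A minimal sanity-check sketch (`check_pairs` is a hypothetical helper, shown here on placeholder arrays standing in for noisy_train / clean_train):

```python
import numpy as np

def check_pairs(noisy, clean):
    """Basic sanity checks for a (noisy, clean) image pair dataset:
    matching shapes, floating-point dtype, and targets in [0,1]."""
    assert noisy.shape == clean.shape, "inputs and targets must align"
    assert noisy.dtype.kind == 'f' and clean.dtype.kind == 'f'
    assert 0.0 <= float(clean.min()) and float(clean.max()) <= 1.0
    return True

# Placeholder arrays; in the notebook, use (noisy_train, clean_train)
ok = check_pairs(np.random.rand(8, 28, 28, 1).astype('float32'),
                 np.random.rand(8, 28, 28, 1).astype('float32'))
```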
Step 3 - Build models
Encoder
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z = layers.Dense(latent_dim)(x)
encoder = keras.Model(inputs, z, name="encoder")
# encoder.summary()
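With "same" padding, each stride-2 convolution halves the spatial size (rounding up): 28 → 14 → 7. This is why the decoder below starts from a 7 × 7 × 64 tensor. A quick check of that arithmetic:

```python
def conv_out(size, stride):
    # "same" padding with stride s gives ceil(size / s)
    return -(-size // stride)

h = 28
h = conv_out(h, 2)   # after the first  Conv2D (stride 2): 14
h = conv_out(h, 2)   # after the second Conv2D (stride 2): 7
flat = h * h * 64    # Flatten size fed into the Dense layers: 3136
```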
Decoder
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(inputs, outputs, name="decoder")
# decoder.summary()
AE
inputs = keras.Input(shape=(28, 28, 1))
latents = encoder(inputs)
outputs = decoder(latents)
ae = keras.Model(inputs,outputs, name="ae")
ae.compile(optimizer=keras.optimizers.Adam(), loss='binary_crossentropy')
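Since the pixel values lie in [0,1] and the decoder ends with a sigmoid, the AE is trained with binary cross-entropy averaged over all pixels. A numpy sketch of that loss (`pixel_bce` is an illustrative name, not a Keras function):

```python
import numpy as np

def pixel_bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over all pixels, for targets and
    predictions in [0,1]; predictions are clipped for numerical safety."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))
```

For a perfect prediction the loss is near 0; predicting a uniform 0.5 on all-ones targets gives -log(0.5) ≈ 0.693.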
Step 4 - Train
20' on a CPU
1'12 on a GPU (V100, IDRIS)
# ---- Callback : Images
#
fidle.utils.mkdir( run_dir + '/images')
filename = run_dir + '/images/image-{epoch:03d}-{i:02d}.jpg'
callback_images = ImagesCallback(filename, x=clean_test[:5], encoder=encoder,decoder=decoder)
chrono = fidle.Chrono()
chrono.start()
history = ae.fit(noisy_train, clean_train,
batch_size = batch_size,
epochs = epochs,
verbose = fit_verbosity,
validation_data = (noisy_test, clean_test),
callbacks = [ callback_images ] )
chrono.show()
Epoch 1/20
/gpfswork/rech/mlh/uja62cb/local/fidle-k3/lib/python3.11/site-packages/keras/src/backend/common/backend_utils.py:88: UserWarning: You might experience inconsistencies accross backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1. warnings.warn(
438/438 ━━━━━━━━━━━━━━━━━━━━ 13s 27ms/step - loss: 0.3270 - val_loss: 0.2320
Epoch 2/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.2121 - val_loss: 0.1848
Epoch 3/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1824 - val_loss: 0.1730
Epoch 4/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1715 - val_loss: 0.1690
Epoch 5/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1677 - val_loss: 0.1659
Epoch 6/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1651 - val_loss: 0.1645
Epoch 7/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1633 - val_loss: 0.1637
Epoch 8/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1629 - val_loss: 0.1633
Epoch 9/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1612 - val_loss: 0.1623
Epoch 10/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1602 - val_loss: 0.1618
Epoch 11/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1595 - val_loss: 0.1614
Epoch 12/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1586 - val_loss: 0.1613
Epoch 13/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1582 - val_loss: 0.1607
Epoch 14/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1578 - val_loss: 0.1608
Epoch 15/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1574 - val_loss: 0.1607
Epoch 16/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1572 - val_loss: 0.1603
Epoch 17/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1562 - val_loss: 0.1607
Epoch 18/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1563 - val_loss: 0.1606
Epoch 19/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1558 - val_loss: 0.1608
Epoch 20/20
438/438 ━━━━━━━━━━━━━━━━━━━━ 12s 27ms/step - loss: 0.1558 - val_loss: 0.1607
Duration : 237.59 seconds
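ModelCheckpoint and TensorBoard are imported at the top of the notebook but not used in this run. A hedged sketch of how they could be added to the callbacks list, to keep the best model and follow the learning process with TensorBoard (the paths below are hypothetical; in the notebook, run_dir comes from fidle.init()):

```python
from keras.callbacks import ModelCheckpoint, TensorBoard

run_dir = './run/K3AE2'  # hypothetical; normally returned by fidle.init()

# Save the weights that achieve the best validation loss
callback_bestmodel = ModelCheckpoint(
    filepath=run_dir + '/models/best_model.keras',
    monitor='val_loss',
    save_best_only=True)

# Log training curves for inspection with: tensorboard --logdir <run_dir>/logs
callback_tensorboard = TensorBoard(log_dir=run_dir + '/logs')
```

These would then be passed alongside callback_images, e.g. `callbacks=[callback_images, callback_bestmodel, callback_tensorboard]` in ae.fit().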
Save model
os.makedirs(f'{run_dir}/models', exist_ok=True)
encoder.save(f'{run_dir}/models/encoder.keras')
decoder.save(f'{run_dir}/models/decoder.keras')
Step 5 - History
fidle.scrawler.history(history, plot={'loss':['loss','val_loss']}, save_as='01-history')
Step 6 - Denoising progress
imgs=[]
for epoch in range(0,epochs,2):
for i in range(5):
filename = run_dir + '/images/image-{epoch:03d}-{i:02d}.jpg'.format(epoch=epoch, i=i)
img = io.imread(filename)
imgs.append(img)
fidle.utils.subtitle('Real images (clean_test) :')
fidle.scrawler.images(clean_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as='02-original-real')
fidle.utils.subtitle('Noisy images (noisy_test) :')
fidle.scrawler.images(noisy_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as='03-original-noisy')
fidle.utils.subtitle('Evolution during the training period (denoised_test) :')
fidle.scrawler.images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, y_padding=0.1, save_as='04-learning')
fidle.utils.subtitle('Noisy images (noisy_test) :')
fidle.scrawler.images(noisy_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as=None)
fidle.utils.subtitle('Real images (clean_test) :')
fidle.scrawler.images(clean_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as=None)
Real images (clean_test) :
Noisy images (noisy_test) :
Evolution during the training period (denoised_test) :
Noisy images (noisy_test) :
Real images (clean_test) :
7.1 - Reload models
encoder = keras.models.load_model(f'{run_dir}/models/encoder.keras')
decoder = keras.models.load_model(f'{run_dir}/models/decoder.keras')
inputs = keras.Input(shape=(28, 28, 1))
latents = encoder(inputs)
outputs = decoder(latents)
ae_reloaded = keras.Model(inputs,outputs, name="ae")
7.2 - Let's make a prediction
denoised_test = ae_reloaded.predict(noisy_test, verbose=0)
print('Denoised images (denoised_test) shape : ',denoised_test.shape)
Denoised images (denoised_test) shape : (14000, 28, 28, 1)
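Beyond visual inspection, denoising quality can be quantified, for instance with the peak signal-to-noise ratio (PSNR) between clean_test and denoised_test. A small numpy sketch (`psnr` is an illustrative helper, not part of the Fidle modules):

```python
import numpy as np

def psnr(clean, denoised, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two image batches whose
    pixel values lie in [0, max_val]; higher means better reconstruction."""
    mse = np.mean((clean - denoised) ** 2)
    return float(10.0 * np.log10(max_val ** 2 / mse))
```

Applied as psnr(clean_test, denoised_test), a higher value indicates a closer match to the clean images; as a reference point, a uniform error of 0.1 on [0,1] images gives exactly 20 dB.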
7.3 - Denoised images
i=random.randint(0,len(denoised_test)-8)
j=i+8
fidle.utils.subtitle('Noisy test images (input):')
fidle.scrawler.images(noisy_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='05-test-noisy')
fidle.utils.subtitle('Denoised images (output):')
fidle.scrawler.images(denoised_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='06-test-predict')
fidle.utils.subtitle('Real test images :')
fidle.scrawler.images(clean_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='07-test-real')
Noisy test images (input):
Denoised images (output):
Real test images :
fidle.end()
End time : 03/03/24 21:13:00
Duration : 00:04:08 633ms
This notebook ends here :-)
https://fidle.cnrs.fr