[K3AE2] - Building and training an AE denoiser model
Episode 1: Construction of a denoising autoencoder and its training with a noisy MNIST dataset.
Objectives:
- Understanding and implementing a denoising autoencoder (AE) neural network
- A first overview of the Keras functional syntax
Since the computational cost is significant, it is preferable to use a very simple dataset such as MNIST.
The use of a GPU is often indispensable.
What we're going to do:
- Define an AE model
- Build the model
- Train it
- Follow the training process with TensorBoard
Data terminology:
- clean_train, clean_test for noiseless images
- noisy_train, noisy_test for noisy images
- denoised_test for denoised images at the output of the model
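The noisy dataset itself was prepared in a previous episode. Purely as an illustration, such noisy images are typically obtained by adding random noise to clean images and clipping back to [0, 1]; the `add_noise` helper and noise level below are hypothetical, not the actual preparation code:

```python
import numpy as np

def add_noise(images, amount=0.3, seed=None):
    """Return a noisy copy of images whose values stay in [0, 1]."""
    rng = np.random.default_rng(seed)
    noisy = images + amount * rng.normal(size=images.shape)
    return np.clip(noisy, 0.0, 1.0)

clean = np.zeros((2, 28, 28, 1))             # dummy 'clean' images
noisy = add_noise(clean, amount=0.3, seed=0)
print(noisy.shape)                           # (2, 28, 28, 1)
```

The clipping step matters: it keeps pixel values in the same [0, 1] range as the clean targets, which the sigmoid output and binary cross-entropy loss used later both assume.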
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
import numpy as np
from skimage import io
import random
from keras import layers
from keras.callbacks import ModelCheckpoint, TensorBoard
from importlib import reload
from modules.MNIST import MNIST
from modules.ImagesCallback import ImagesCallback
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3AE2')
FIDLE - Environment initialization
Version : 2.3.2
Run id : K3AE2
Run dir : ./run/K3AE2
Datasets dir : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time : 22/12/24 21:24:22
Hostname : r3i6n0 (Linux)
Tensorflow log level : Info + Warning + Error (=0)
Update keras cache : False
Update torch cache : False
Save figs : ./run/K3AE2/figs (True)
keras : 3.7.0
numpy : 2.1.2
sklearn : 1.5.2
yaml : 6.0.2
skimage : 0.24.0
matplotlib : 3.9.2
pandas : 2.2.3
torch : 2.5.0
1.2 - Parameters
- prepared_dataset : filename of the prepared dataset (needs about 400 MB; can be stored in ./data)
- dataset_seed : random seed for shuffling the dataset
- scale : fraction of the dataset to use (1. for 100%)
- latent_dim : dimension of the latent space
- train_prop : proportion used for training (the rest being for the test)
- batch_size : batch size
- epochs : number of training epochs
- fit_verbosity : verbosity during training: 0 = silent, 1 = progress bar, 2 = one line per epoch
Note: with scale=.2 and epochs=20, training takes about 3'30 on a laptop.
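The effect of scale and train_prop is simple arithmetic: full MNIST has 70,000 images, so scale=1. and train_prop=.8 yield 56,000 training and 14,000 test images. A minimal sketch (`split_sizes` is a hypothetical helper, not part of the Fidle modules):

```python
def split_sizes(n_total, scale, train_prop):
    """Number of (train, test) samples after scaling and splitting."""
    n = int(n_total * scale)
    n_train = int(n * train_prop)
    return n_train, n - n_train

print(split_sizes(70000, 1.0, 0.8))   # (56000, 14000)
print(split_sizes(70000, 0.1, 0.8))   # (5600, 1400)
```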
prepared_dataset = './data/mnist-noisy.h5'
dataset_seed = 123
scale = .1
latent_dim = 10
train_prop = .8
batch_size = 128
epochs = 20
fit_verbosity = 1
Override parameters (batch mode) - Just forget this cell
fidle.override('prepared_dataset', 'dataset_seed', 'scale', 'latent_dim')
fidle.override('train_prop', 'batch_size', 'epochs', 'fit_verbosity')
** Overrided parameters : **
scale : 1
latent_dim : 10
** Overrided parameters : **
epochs : 20
fit_verbosity : 2
Step 2 - Retrieve dataset
With our MNIST class, in one call, we can reload, rescale, shuffle and split our previously saved dataset :-)
clean_train,clean_test, noisy_train,noisy_test, _,_ = MNIST.reload_prepared_dataset(scale = scale,
train_prop = train_prop,
seed = dataset_seed,
shuffle = True,
filename=prepared_dataset )
Loaded. rescaled (1). Seeded (123)
Shuffled. splited (0.8).
clean_train shape is : (56000, 28, 28, 1)
clean_test shape is : (14000, 28, 28, 1)
noisy_train shape is : (56000, 28, 28, 1)
noisy_test shape is : (14000, 28, 28, 1)
class_train shape is : (56000,)
class_test shape is : (14000,)
Blake2b digest is : 3e97fec95d853b5a2615
Step 3 - Build models
Encoder
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z = layers.Dense(latent_dim)(x)
encoder = keras.Model(inputs, z, name="encoder")
# encoder.summary()
Decoder
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(inputs, outputs, name="decoder")
# decoder.summary()
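The spatial shapes of encoder and decoder mirror each other: each stride-2 Conv2D with 'same' padding halves the resolution (28 → 14 → 7), and each stride-2 Conv2DTranspose doubles it (7 → 14 → 28), which is why the decoder's Dense layer targets a 7×7×64 tensor. A quick sanity check of that arithmetic (pure Python, not Keras):

```python
def down(size, stride=2):
    # stride-2 Conv2D with 'same' padding: output = ceil(input / stride)
    return -(-size // stride)

def up(size, stride=2):
    # stride-2 Conv2DTranspose with 'same' padding: output = input * stride
    return size * stride

s = down(down(28))   # encoder: 28 -> 14 -> 7
print(s)             # 7
s = up(up(s))        # decoder: 7 -> 14 -> 28
print(s)             # 28
```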
AE
inputs = keras.Input(shape=(28, 28, 1))
latents = encoder(inputs)
outputs = decoder(latents)
ae = keras.Model(inputs,outputs, name="ae")
ae.compile(optimizer=keras.optimizers.Adam(), loss='binary_crossentropy')
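binary_crossentropy is a reasonable reconstruction loss here because pixel values are scaled to [0, 1] and the decoder ends with a sigmoid; the loss is then averaged over all pixels. A minimal sketch of the computation (a NumPy re-implementation for illustration, not the Keras internal):

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, as applied per pixel to reconstructions."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(y_pred)
                           + (1.0 - y_true) * np.log(1.0 - y_pred))))

# two 'pixels', both predicted with confidence 0.9 toward the right value
print(round(bce(np.array([0.0, 1.0]), np.array([0.1, 0.9])), 5))  # 0.10536
```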
Step 4 - Train
20' on a CPU
1'12 on a GPU (V100, IDRIS)
# ---- Callback : Images
#
fidle.utils.mkdir( run_dir + '/images')
filename = run_dir + '/images/image-{epoch:03d}-{i:02d}.jpg'
callback_images = ImagesCallback(filename, x=clean_test[:5], encoder=encoder,decoder=decoder)
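The objectives mention following training with TensorBoard; the TensorBoard and ModelCheckpoint callbacks imported earlier could be wired in alongside the images callback, for example (a sketch only — the log and checkpoint paths are illustrative, not from the original run):

```python
from keras.callbacks import ModelCheckpoint, TensorBoard

# Write training curves for TensorBoard (illustrative path)
callback_tensorboard = TensorBoard(log_dir=run_dir + '/logs')

# Keep the weights with the best validation loss (illustrative path)
callback_bestmodel = ModelCheckpoint(filepath=run_dir + '/models/best_model.keras',
                                     monitor='val_loss',
                                     save_best_only=True)

# They would then be passed to fit() as:
# callbacks=[callback_images, callback_tensorboard, callback_bestmodel]
```

Training curves can then be inspected with `tensorboard --logdir ./run/K3AE2/logs`.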
chrono = fidle.Chrono()
chrono.start()
history = ae.fit(noisy_train, clean_train,
batch_size = batch_size,
epochs = epochs,
verbose = fit_verbosity,
validation_data = (noisy_test, clean_test),
callbacks = [ callback_images ] )
chrono.show()
Epoch 1/20
/lustre/fswork/projects/rech/mlh/uja62cb/local/fidle-k3/lib/python3.12/site-packages/keras/src/backend/common/backend_utils.py:91: UserWarning: You might experience inconsistencies across backends when calling conv transpose with kernel_size=3, stride=2, dilation_rate=1, padding=same, output_padding=1. warnings.warn(
438/438 - 5s - 13ms/step - loss: 0.2672 - val_loss: 0.2139
Epoch 2/20
438/438 - 5s - 11ms/step - loss: 0.1910 - val_loss: 0.1807
Epoch 3/20
438/438 - 5s - 11ms/step - loss: 0.1744 - val_loss: 0.1717
Epoch 4/20
438/438 - 5s - 11ms/step - loss: 0.1691 - val_loss: 0.1699
Epoch 5/20
438/438 - 5s - 11ms/step - loss: 0.1659 - val_loss: 0.1655
Epoch 6/20
438/438 - 5s - 11ms/step - loss: 0.1637 - val_loss: 0.1641
Epoch 7/20
438/438 - 5s - 11ms/step - loss: 0.1620 - val_loss: 0.1623
Epoch 8/20
438/438 - 5s - 11ms/step - loss: 0.1599 - val_loss: 0.1619
Epoch 9/20
438/438 - 5s - 11ms/step - loss: 0.1571 - val_loss: 0.1575
Epoch 10/20
438/438 - 5s - 11ms/step - loss: 0.1538 - val_loss: 0.1550
Epoch 11/20
438/438 - 5s - 11ms/step - loss: 0.1523 - val_loss: 0.1547
Epoch 12/20
438/438 - 5s - 11ms/step - loss: 0.1512 - val_loss: 0.1551
Epoch 13/20
438/438 - 5s - 11ms/step - loss: 0.1505 - val_loss: 0.1535
Epoch 14/20
438/438 - 5s - 11ms/step - loss: 0.1498 - val_loss: 0.1531
Epoch 15/20
438/438 - 5s - 11ms/step - loss: 0.1491 - val_loss: 0.1534
Epoch 16/20
438/438 - 5s - 11ms/step - loss: 0.1485 - val_loss: 0.1529
Epoch 17/20
438/438 - 5s - 11ms/step - loss: 0.1481 - val_loss: 0.1528
Epoch 18/20
438/438 - 5s - 11ms/step - loss: 0.1476 - val_loss: 0.1525
Epoch 19/20
438/438 - 5s - 11ms/step - loss: 0.1472 - val_loss: 0.1524
Epoch 20/20
438/438 - 5s - 11ms/step - loss: 0.1469 - val_loss: 0.1529
Duration : 99.88 seconds
Save model
os.makedirs(f'{run_dir}/models', exist_ok=True)
encoder.save(f'{run_dir}/models/encoder.keras')
decoder.save(f'{run_dir}/models/decoder.keras')
Step 5 - History
fidle.scrawler.history(history, plot={'loss':['loss','val_loss']}, save_as='01-history')
Step 6 - Denoising progress
imgs=[]
for epoch in range(0,epochs,2):
for i in range(5):
filename = run_dir + '/images/image-{epoch:03d}-{i:02d}.jpg'.format(epoch=epoch, i=i)
img = io.imread(filename)
imgs.append(img)
fidle.utils.subtitle('Real images (clean_test) :')
fidle.scrawler.images(clean_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as='02-original-real')
fidle.utils.subtitle('Noisy images (noisy_test) :')
fidle.scrawler.images(noisy_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as='03-original-noisy')
fidle.utils.subtitle('Evolution during the training period (denoised_test) :')
fidle.scrawler.images(imgs, None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, y_padding=0.1, save_as='04-learning')
fidle.utils.subtitle('Noisy images (noisy_test) :')
fidle.scrawler.images(noisy_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as=None)
fidle.utils.subtitle('Real images (clean_test) :')
fidle.scrawler.images(clean_test[:5], None, indices='all', columns=5, x_size=2,y_size=2, interpolation=None, save_as=None)
Real images (clean_test) :
Noisy images (noisy_test) :
Evolution during the training period (denoised_test) :
Noisy images (noisy_test) :
Real images (clean_test) :
Step 7 - Reload and reuse the model
7.1 - Reload the model
encoder = keras.models.load_model(f'{run_dir}/models/encoder.keras')
decoder = keras.models.load_model(f'{run_dir}/models/decoder.keras')
inputs = keras.Input(shape=(28, 28, 1))
latents = encoder(inputs)
outputs = decoder(latents)
ae_reloaded = keras.Model(inputs,outputs, name="ae")
7.2 - Let's make a prediction
denoised_test = ae_reloaded.predict(noisy_test, verbose=0)
print('Denoised images (denoised_test) shape : ',denoised_test.shape)
Denoised images (denoised_test) shape : (14000, 28, 28, 1)
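Beyond visual inspection, denoising quality can also be quantified, for example with PSNR; this metric is not computed in the original notebook, and `psnr` below is an illustrative helper:

```python
import numpy as np

def psnr(a, b, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two images in [0, max_val]."""
    mse = float(np.mean((a - b) ** 2))
    return 20 * np.log10(max_val) - 10 * np.log10(mse)

a = np.zeros((4, 4))
b = np.full((4, 4), 0.1)     # uniform error of 0.1 -> MSE = 0.01
print(round(psnr(a, b), 2))  # 20.0
```

One could then compare psnr(clean_test, denoised_test) against psnr(clean_test, noisy_test) to measure how much noise the model actually removed.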
7.3 - Denoised images
i=random.randint(0,len(denoised_test)-8)
j=i+8
fidle.utils.subtitle('Noisy test images (input):')
fidle.scrawler.images(noisy_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='05-test-noisy')
fidle.utils.subtitle('Denoised images (output):')
fidle.scrawler.images(denoised_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='06-test-predict')
fidle.utils.subtitle('Real test images :')
fidle.scrawler.images(clean_test[i:j], None, indices='all', columns=8, x_size=2,y_size=2, interpolation=None, save_as='07-test-real')
Noisy test images (input):
Denoised images (output):
Real test images :
fidle.end()
End time : 22/12/24 21:26:11
Duration : 00:01:49 736ms
This notebook ends here :-)
https://fidle.cnrs.fr