[PLSHEEP3] - A DCGAN to Draw a Sheep, using PyTorch Lightning
"Draw me a sheep", revisited with a DCGAN, using PyTorch Lightning
Objectives:
- Build and train a DCGAN model with the Quick Draw dataset
- Understand DCGAN
The Quick Draw dataset contains about 50,000,000 drawings made by real people.
We are using the subset of sheep drawings (the sheep.npy file loaded below contains 126,121 of them).
To get the dataset: https://github.com/googlecreativelab/quickdraw-dataset
Datasets as NumPy bitmap files: https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap
Sheep dataset: https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy (94.3 MB)
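In the numpy bitmap files, each row is one drawing rendered as a 28x28 grayscale image and flattened to 784 uint8 values. This is why the DataModule log below first reports a shape of (126121, 784). The plain NumPy sketch below loads such a file and prints a crude text preview of one drawing. The helper names `load_quickdraw_bitmaps` and `ascii_preview` are hypothetical and are not part of fidle or of the dataset tools.

```python
import numpy as np


def load_quickdraw_bitmaps(source):
    """Load a Quick Draw numpy bitmap file as an (N, 28, 28) uint8 array.

    `source` is a path or an open binary file. Each stored row is one
    drawing: 784 = 28 * 28 grayscale values, laid out row by row.
    """
    flat = np.load(source)              # shape (N, 784), dtype uint8
    return flat.reshape(-1, 28, 28)     # one 28x28 image per drawing


def ascii_preview(image, threshold=128):
    """Text rendering of one 28x28 image: '#' if pixel >= threshold, else '.'."""
    return "\n".join(
        "".join("#" if pixel >= threshold else "." for pixel in row)
        for row in image
    )
```

For example, once sheep.npy has been downloaded, `images = load_quickdraw_bitmaps('sheep.npy')` followed by `print(ascii_preview(images[0]))` shows the first sheep.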
What we're going to do:
- Have a look at the dataset
- Define a GAN model
- Build the model
- Train it
- Have a look at the results
import os
import sys
import shutil
import numpy as np
import torch
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
import fidle
from modules.QuickDrawDataModule import QuickDrawDataModule
from modules.GAN import GAN
from modules.WGANGP import WGANGP
from modules.Generators import *
from modules.Discriminators import *
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('PLSHEEP3')
FIDLE - Environment initialization
Version : 2.3.2
Run id : PLSHEEP3_2
Run dir : ./run/PLSHEEP3_2
Datasets dir : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time : 22/12/24 21:41:56
Hostname : r3i6n0 (Linux)
Tensorflow log level : Info + Warning + Error (=0)
Update keras cache : False
Update torch cache : False
Save figs : ./run/PLSHEEP3_2/figs (True)
numpy : 2.1.2
sklearn : 1.5.2
yaml : 6.0.2
matplotlib : 3.9.2
pandas : 2.2.3
torch : 2.5.0
torchvision : 0.20.0a0+afc54f7
lightning : 2.4.0
** run_id has been overrided from PLSHEEP3 to PLSHEEP3_2
A few parameters
With scale=1 and epochs=20, training takes about 22 minutes on a V100.
latent_dim = 128
gan_name = 'WGANGP'
generator_name = 'Generator_2'
discriminator_name = 'Discriminator_3'
scale = 0.001
epochs = 4
num_workers = 2
lr = 0.0001
b1 = 0.5
b2 = 0.999
lambda_gp = 10
batch_size = 64
num_img = 48
fit_verbosity = 2
dataset_file = datasets_dir+'/QuickDraw/origine/sheep.npy'
data_shape = (28,28,1)
Override parameters (batch mode) - you can ignore this cell
fidle.override('latent_dim', 'gan_name', 'generator_name', 'discriminator_name')
fidle.override('epochs', 'lr', 'b1', 'b2', 'batch_size', 'num_img', 'fit_verbosity')
fidle.override('dataset_file', 'data_shape', 'scale', 'num_workers' )
** Overrided parameters : **
gan_name : WGANGP
generator_name : Generator_2
discriminator_name : Discriminator_3
** Overrided parameters : **
epochs : 30
batch_size : 64
** Overrided parameters : **
scale : 1
num_workers : 2
Cleaning
# Comment out these lines to keep the outputs of previous runs
shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)
shutil.rmtree(f'{run_dir}/models', ignore_errors=True)
shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)
Step 2 - Get some nice data
Get a nice DataModule
Our DataModule is defined in ./modules/QuickDrawDataModule.py.
It is a LightningDataModule.
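The code of QuickDrawDataModule is not shown in this notebook. Judging only from its log (scale=1 keeps all 126,121 drawings, and the data is reshaped from (N, 784) to (N, 28, 28, 1)), its preprocessing might look roughly like the NumPy sketch below. Several points are assumptions that do not come from the module: `scale` is the fraction of rows kept, taken from the start of the array; pixels are divided by 255; and `iterate_minibatches` is a hypothetical stand-in for the shuffling DataLoader that a LightningDataModule returns from train_dataloader().

```python
import numpy as np


def prepare_sheep_data(raw, scale=1.0):
    """Hypothetical preprocessing: subset, normalize, reshape to (N, 28, 28, 1).

    raw   : uint8 array of shape (N, 784), as stored in sheep.npy
    scale : fraction of drawings to keep (assumption: the first rows)
    """
    n = max(1, int(len(raw) * scale))              # keep at least one drawing
    subset = raw[:n].astype(np.float32) / 255.0    # assumption: pixels mapped to [0, 1]
    return subset.reshape(-1, 28, 28, 1)           # channels-last, as in the log


def iterate_minibatches(data, batch_size, seed=0):
    """Yield shuffled mini-batches; the last one may be smaller (like drop_last=False)."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(data))
    for start in range(0, len(data), batch_size):
        yield data[order[start:start + batch_size]]
```

With scale=0.001, as in the default parameters above, this sketch would keep int(126121 * 0.001) = 126 drawings.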
dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=num_workers)
dm.setup()
---- QuickDrawDataModule initialization ----------------------------
with : scale=1 batch size=64
DataModule Setup :
Original dataset shape : (126121, 784)
Rescaled dataset shape : (126121, 784)
Final dataset shape : torch.Size([126121, 28, 28, 1])
Dataset loaded and ready.
Have a look
dl = dm.train_dataloader()
batch_data = next(iter(dl))
fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
Step 3 - Get a nice GAN model
Our generators are defined in ./modules/Generators.py
Our discriminators are defined in ./modules/Discriminators.py
Our GANs are defined in ./modules/GAN.py and ./modules/WGANGP.py
Retrieve classes by name
To be very flexible, we just specify class names as parameters.
The code below retrieves classes from their names.
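The same lookup can be wrapped in a small helper that fails with a readable message when a class name is misspelled. This is a sketch that assumes the classes are attributes of an importable module (`__main__` in the notebook, where the `import *` lines put them). `get_class` is a hypothetical name.

```python
import importlib


def get_class(name, module_name="__main__"):
    """Return attribute `name` of an importable module, with a clear error if missing."""
    # import_module returns the already-loaded module from sys.modules when present
    module = importlib.import_module(module_name)
    try:
        return getattr(module, name)
    except AttributeError:
        raise ValueError(f"No class named {name!r} in module {module_name!r}") from None
```

In the notebook, `Generator_ = get_class(generator_name)` would replace the bare getattr call. Raising a ValueError that names the missing class makes a typo in a parameter such as generator_name easier to spot than a generic AttributeError.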
# The `from modules... import *` lines put the classes in the notebook's __main__ namespace
module=sys.modules['__main__']
Generator_ = getattr(module, generator_name)
Discriminator_ = getattr(module, discriminator_name)
GAN_ = getattr(module, gan_name)
Basic test - just to be sure it (could) work... ;-)
generator = Generator_( latent_dim=latent_dim, data_shape=data_shape )
discriminator = Discriminator_( latent_dim=latent_dim, data_shape=data_shape )
print('\nFew tests :\n')
z = torch.randn(batch_size, latent_dim)
print('z size : ',z.size())
fake_img = generator.forward(z)
print('fake_img : ', fake_img.size())
p = discriminator.forward(fake_img)
print('pred fake : ', p.size())
print('batch_data : ',batch_data.size())
p = discriminator.forward(batch_data)
print('pred real : ', p.size())
print('\nShow fake images :')
nimg = fake_img.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
init generator 2 : 128 to (28, 28, 1)
init discriminator 3 : (28, 28, 1) to sigmoid
Few tests :
z size : torch.Size([64, 128])
fake_img : torch.Size([64, 28, 28, 1])
pred fake : torch.Size([64, 1])
batch_data : torch.Size([64, 28, 28, 1])
pred real : torch.Size([64, 1])
Show fake images :
print('Fake images : ', fake_img.size())
print('Batch size : ', batch_data.size())
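e = torch.distributions.uniform.Uniform(0, 1).sample([batch_size,1])
e = e[:, None, None]   # (batch_size, 1) -> (batch_size, 1, 1, 1): one epsilon per image, broadcast over (28, 28, 1)
i = fake_img * e + (1-e)*batch_data   # random interpolation between fake and real images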
print('\ninterpolate images :')
nimg = i.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
Fake images : torch.Size([64, 28, 28, 1])
Batch size : torch.Size([64, 28, 28, 1])
interpolate images :
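Interpolated images like these are the input of the WGAN-GP gradient penalty. For reference, the standard WGAN-GP objectives are written below with the notation of the cell above: x is a real batch (batch_data), \tilde{x} = G(z) is a fake batch (fake_img), \epsilon ~ U(0,1) is drawn once per image, and \lambda is lambda_gp. We assume that ./modules/WGANGP.py follows this formulation, but its code is not shown here.

```latex
\hat{x} = \epsilon\,\tilde{x} + (1-\epsilon)\,x, \qquad \epsilon \sim \mathcal{U}(0,1), \qquad \tilde{x} = G(z)

\mathcal{L}_D = \mathbb{E}\big[D(\tilde{x})\big] - \mathbb{E}\big[D(x)\big]
              + \lambda\,\mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]

\mathcal{L}_G = -\,\mathbb{E}\big[D(\tilde{x})\big]
```

The gradient norm is computed per image, over all 28x28x1 pixels. This is why each image needs its own \epsilon, hence the (batch_size, 1, 1, 1) shape of e.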
GAN model
To simplify our code, the GAN class is defined separately in the module ./modules/GAN.py
Passing the generator and discriminator class names as parameters keeps the code modular and makes it possible to use PyTorch Lightning (PL) checkpoints.
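Why pass class names rather than objects? The reload at the end of this notebook calls load_from_checkpoint with only a path and still rebuilds Generator_2 and Discriminator_3. The checkpoint therefore evidently stores the constructor arguments, which in Lightning is usually done with save_hyperparameters(). Strings such as 'Generator_2' are trivial to store and restore; module objects are not. The pure-Python sketch below mimics that round trip with JSON. `ToyGAN`, `REGISTRY` and the stand-in classes are hypothetical and only illustrate the pattern.

```python
import json


# Hypothetical stand-ins for classes from ./modules/Generators.py and ./modules/Discriminators.py
class Generator_2:
    def __init__(self, latent_dim, data_shape):
        self.latent_dim, self.data_shape = latent_dim, tuple(data_shape)


class Discriminator_3:
    def __init__(self, latent_dim, data_shape):
        self.latent_dim, self.data_shape = latent_dim, tuple(data_shape)


REGISTRY = {cls.__name__: cls for cls in (Generator_2, Discriminator_3)}


class ToyGAN:
    """Builds its submodels from class *names*, so its hyperparameters stay plain data."""

    def __init__(self, latent_dim, data_shape, generator_name, discriminator_name):
        self.hparams = dict(latent_dim=latent_dim, data_shape=list(data_shape),
                            generator_name=generator_name,
                            discriminator_name=discriminator_name)
        self.generator = REGISTRY[generator_name](latent_dim, data_shape)
        self.discriminator = REGISTRY[discriminator_name](latent_dim, data_shape)

    def save(self):
        # Only hyperparameters are serialized here; a real checkpoint also stores weights
        return json.dumps(self.hparams)

    @classmethod
    def load(cls, text):
        # Re-run __init__ with the stored hyperparameters, roughly as load_from_checkpoint does
        return cls(**json.loads(text))
```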
gan = GAN_( data_shape = data_shape,
lr = lr,
b1 = b1,
b2 = b2,
lambda_gp = lambda_gp,
batch_size = batch_size,
latent_dim = latent_dim,
generator_name = generator_name,
discriminator_name = discriminator_name)
---- GAN initialization --------------------------------------------
Hyperarameters are :
data_shape : (28, 28, 1)
latent_dim : 128
lr : 0.0001
b1 : 0.5
b2 : 0.999
batch_size : 64
lambda_gp : 10
generator_name : Generator_2
discriminator_name : Discriminator_3
Submodels :
init generator 2 : 128 to (28, 28, 1)
init discriminator 3 : (28, 28, 1) to sigmoid
# ---- for tensorboard logs
#
logger = TensorBoardLogger( save_dir = f'{run_dir}',
name = 'tb_logs' )
log_dir = os.path.abspath(f'{run_dir}/tb_logs')
print('To access the logs with tensorboard, use this command line :')
print(f'tensorboard --logdir {log_dir}')
# ---- To save checkpoints
#
callback_checkpoints = ModelCheckpoint( dirpath = f'{run_dir}/models',
filename = 'bestModel',
save_top_k = 1,
save_last = True,
every_n_epochs = 1,
monitor = "g_loss")
To access the logs with tensorboard, use this command line : tensorboard --logdir /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/fidle/DCGAN.Lightning/run/PLSHEEP3_2/tb_logs
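With monitor="g_loss" and save_top_k=1, ModelCheckpoint keeps a single checkpoint: the one with the best logged g_loss. This assumes the module logs g_loss with self.log. The default mode is "min", so "best" means "lowest". With save_last=True, Lightning also writes last.ckpt. The sketch below reproduces only this selection rule on a list of per-epoch values (hypothetical numbers). It is not Lightning's implementation.

```python
def best_epoch(metric_per_epoch, mode="min"):
    """Index of the epoch whose checkpoint would be kept with save_top_k=1."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    pick = min if mode == "min" else max
    # min()/max() return the first index among equal values
    return pick(range(len(metric_per_epoch)), key=metric_per_epoch.__getitem__)


def checkpoints_on_disk(metric_per_epoch, mode="min"):
    """Which epoch each file would hold, with filename='bestModel' and save_last=True."""
    return {
        "bestModel.ckpt": best_epoch(metric_per_epoch, mode),
        "last.ckpt": len(metric_per_epoch) - 1,
    }


# Hypothetical g_loss values, one per epoch
print(checkpoints_on_disk([3.2, 1.5, 2.0, 1.5, 1.8]))   # {'bestModel.ckpt': 1, 'last.ckpt': 4}
```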
Train it
trainer = Trainer(
accelerator = "auto",
max_epochs = epochs,
callbacks = [callback_checkpoints],
log_every_n_steps = batch_size,
logger = logger
)
trainer.fit(gan, dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
DataModule Setup :
Original dataset shape : (126121, 784)
Rescaled dataset shape : (126121, 784)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Final dataset shape : torch.Size([126121, 28, 28, 1])
Dataset loaded and ready.
  | Name          | Type            | Params | Mode  | In sizes | Out sizes
--------------------------------------------------------------------------------------
0 | generator     | Generator_2     | 780 K  | train | [2, 128] | [2, 28, 28, 1]
1 | discriminator | Discriminator_3 | 401 K  | train | ?        | ?
--------------------------------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.728     Total estimated model params size (MB)
35        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.
`Trainer.fit` stopped: `max_epochs=30` reached.
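Some quick arithmetic helps when reading the training log. The sketch below takes the 126,121 drawings, batch_size=64 and the 30 epochs reported above. It estimates the number of batches per epoch and the number of logging steps triggered by log_every_n_steps=batch_size. It assumes that the DataLoader keeps the last incomplete batch (drop_last=False) and that logging advances once per batch; the notebook confirms neither.

```python
import math


def training_schedule(n_samples, batch_size, epochs, log_every_n_steps, drop_last=False):
    """Rough batch and logging counts, assuming one logging step per batch."""
    if drop_last:
        batches_per_epoch = n_samples // batch_size            # incomplete last batch dropped
    else:
        batches_per_epoch = math.ceil(n_samples / batch_size)  # incomplete last batch kept
    total_batches = batches_per_epoch * epochs
    return {
        "batches_per_epoch": batches_per_epoch,
        "total_batches": total_batches,
        "approx_log_points": total_batches // log_every_n_steps,
    }


print(training_schedule(126121, 64, 30, 64))
# {'batches_per_epoch': 1971, 'total_batches': 59130, 'approx_log_points': 923}
```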
Step 6 - Reload our best model
Note :
gan = GAN.load_from_checkpoint(f'{run_dir}/models/bestModel.ckpt')
---- GAN initialization --------------------------------------------
Hyperarameters are :
data_shape : (28, 28, 1)
latent_dim : 128
lr : 0.0001
b1 : 0.5
b2 : 0.999
batch_size : 64
generator_name : Generator_2
discriminator_name : Discriminator_3
lambda_gp : 10
Submodels :
init generator 2 : 128 to (28, 28, 1)
init discriminator 3 : (28, 28, 1) to sigmoid
nb_images = 96
z = torch.randn(nb_images, latent_dim)
print('z size : ',z.size())
if torch.cuda.is_available(): z=z.cuda()
fake_img = gan.generator.forward(z)
print('fake_img : ', fake_img.size())
nimg = fake_img.cpu().detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(nb_images), columns=12, x_size=1, y_size=1,
y_padding=0,spines_alpha=0, save_as='01-Sheeps')
z size : torch.Size([96, 128]) fake_img : torch.Size([96, 28, 28, 1])
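fidle.scrawler.images lays the generated images out in rows of 12. You may want the same mosaic as a single array, for instance to save it or to pass it to matplotlib's imshow. A NumPy sketch could look like this; `make_grid` is a hypothetical helper and is not part of fidle.

```python
import numpy as np


def make_grid(images, columns, pad_value=0.0):
    """Tile (N, H, W) images into one (rows*H, columns*W) array, filled row by row.

    Empty cells in the last row are filled with pad_value.
    """
    images = np.asarray(images)
    n, h, w = images.shape
    rows = -(-n // columns)                                # ceiling division
    grid = np.full((rows * h, columns * w), pad_value, dtype=images.dtype)
    for k in range(n):
        r, c = divmod(k, columns)                          # cell position of image k
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = images[k]
    return grid
```

For the 96 images above, `make_grid(nimg.reshape(-1, 28, 28), columns=12)` gives a (224, 336) array: 8 rows and 12 columns of 28x28 tiles.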
fidle.end()
End time : 22/12/24 22:10:18
Duration : 00:28:22 197ms
This notebook ends here :-)
https://fidle.cnrs.fr