
[PLSHEEP3] - A DCGAN to Draw a Sheep, using PyTorch Lightning¶

"Draw me a sheep", revisited with a DCGAN, using PyTorch Lightning

Objectives :¶

  • Build and train a DCGAN model with the Quick Draw dataset
  • Understand how a DCGAN works

The Quick Draw dataset contains about 50,000,000 drawings made by real people...
We are using a subset of 117,555 sheep drawings.
To get the dataset : https://github.com/googlecreativelab/quickdraw-dataset
Datasets in numpy bitmap format : https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap
Sheep dataset : https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy (94.3 MB)
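
Once downloaded, the file can be inspected with plain NumPy. A minimal sketch of the expected layout (a synthetic array stands in for sheep.npy, and the [0, 1] normalization is an assumption about the usual preprocessing):

```python
import numpy as np

# Stand-in for: data = np.load('sheep.npy')
# Real layout: one row of 784 uint8 values (a 28x28 grayscale bitmap) per drawing
data = np.random.randint(0, 256, size=(1000, 784), dtype=np.uint8)

# Reshape to image tensors and normalize to [0, 1]
images = data.reshape(-1, 28, 28, 1).astype(np.float32) / 255.0

print(images.shape)   # (1000, 28, 28, 1)
```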

What we're going to do :¶

  • Have a look at the dataset
  • Define a GAN model
  • Build the model
  • Train it
  • Have a look at the results

Step 1 - Init and parameters¶

Python init¶

In [1]:
import os
import sys
import shutil

import numpy as np
import torch
from lightning import Trainer
from lightning.pytorch.callbacks                        import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard              import TensorBoardLogger

import fidle

from modules.QuickDrawDataModule import QuickDrawDataModule

from modules.GAN                 import GAN
from modules.WGANGP              import WGANGP
from modules.Generators          import *
from modules.Discriminators      import *

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('PLSHEEP3')


FIDLE - Environment initialization

Version              : 2.3.2
Run id               : PLSHEEP3_2
Run dir              : ./run/PLSHEEP3_2
Datasets dir         : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 22/12/24 21:41:56
Hostname             : r3i6n0 (Linux)
Tensorflow log level : Info + Warning + Error  (=0)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/PLSHEEP3_2/figs (True)
numpy                : 2.1.2
sklearn              : 1.5.2
yaml                 : 6.0.2
matplotlib           : 3.9.2
pandas               : 2.2.3
torch                : 2.5.0
torchvision          : 0.20.0a0+afc54f7
lightning            : 2.4.0

** run_id has been overrided from PLSHEEP3 to PLSHEEP3_2

A few parameters¶

With scale=1 and epochs=20, about 22 minutes are needed on a V100.

In [2]:
latent_dim          = 128

gan_name            = 'WGANGP'
generator_name      = 'Generator_2'
discriminator_name  = 'Discriminator_3'
    
scale               = 0.001
epochs              = 4
num_workers         = 2
lr                  = 0.0001
b1                  = 0.5
b2                  = 0.999
lambda_gp           = 10
batch_size          = 64
num_img             = 48
fit_verbosity       = 2
    
dataset_file        = datasets_dir+'/QuickDraw/origine/sheep.npy' 
data_shape          = (28,28,1)
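
The exact meaning of scale is defined in ./modules/QuickDrawDataModule.py; a plausible reading, keeping a random fraction of the drawings, can be sketched as follows (subsample is a hypothetical helper, not part of the Fidle modules):

```python
import numpy as np

def subsample(data, scale, seed=42):
    """Hypothetical helper: keep a random fraction `scale` of the rows."""
    rng = np.random.default_rng(seed)
    n = max(1, int(len(data) * scale))
    idx = rng.choice(len(data), size=n, replace=False)
    return data[idx]

data  = np.zeros((117555, 784), dtype=np.uint8)   # same size as the sheep dataset
small = subsample(data, scale=0.001)
print(len(small))   # 117
```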

Override parameters (batch mode) - Just forget this cell

In [3]:
fidle.override('latent_dim', 'gan_name', 'generator_name', 'discriminator_name')  
fidle.override('epochs', 'lr', 'b1', 'b2', 'batch_size', 'num_img', 'fit_verbosity')
fidle.override('dataset_file', 'data_shape', 'scale', 'num_workers' )
** Overrided parameters : **
gan_name             : WGANGP
generator_name       : Generator_2
discriminator_name   : Discriminator_3
** Overrided parameters : **
epochs               : 30
batch_size           : 64
** Overrided parameters : **
scale                : 1
num_workers          : 2

Cleaning¶

In [4]:
# You can comment these lines to keep each run...
shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)
shutil.rmtree(f'{run_dir}/models', ignore_errors=True)
shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)

Step 2 - Get some nice data¶

Get a Nice DataModule¶

Our DataModule is defined in ./modules/QuickDrawDataModule.py
This is a LightningDataModule
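
Under the hood, a LightningDataModule mostly wraps a standard torch DataLoader; what dm.train_dataloader() hands to the trainer can be sketched with plain PyTorch (the (N, 28, 28, 1) shape matches the setup log below; the rest is an assumption about QuickDrawDataModule's internals):

```python
import torch
from torch.utils.data import DataLoader

# Stand-in for the rescaled QuickDraw tensor: (N, 28, 28, 1) floats in [0, 1]
x = torch.rand(256, 28, 28, 1)

# train_dataloader() presumably returns something equivalent to this
loader = DataLoader(x, batch_size=64, shuffle=True)

batch = next(iter(loader))
print(batch.shape)   # torch.Size([64, 28, 28, 1])
```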

In [5]:
dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=num_workers)
dm.setup()
---- QuickDrawDataModule initialization ----------------------------
with : scale=1  batch size=64

DataModule Setup :
Original dataset shape :  (126121, 784)
Rescaled dataset shape :  (126121, 784)
Final dataset shape    :  torch.Size([126121, 28, 28, 1])
Dataset loaded and ready.

Have a look¶

In [6]:
dl         = dm.train_dataloader()
batch_data = next(iter(dl))

fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')
Saved: ./run/PLSHEEP3_2/figs/01-Sheeps

Step 3 - Get a nice GAN model¶

Our Generators are defined in ./modules/Generators.py
Our Discriminators are defined in ./modules/Discriminators.py

Our GANs are defined in :

  • ./modules/GAN.py
  • ./modules/WGANGP.py

Retrieve class by name¶

To be very flexible, we just specify class names as parameters.
The code below retrieves classes from their names.

In [7]:
module=sys.modules['__main__']
Generator_     = getattr(module, generator_name)
Discriminator_ = getattr(module, discriminator_name)
GAN_           = getattr(module, gan_name)
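
The same pattern can be checked with placeholder classes (the ones below are dummies, not the real generators and discriminators):

```python
import sys

# Placeholder classes, standing in for the real ones from
# modules.Generators / modules.Discriminators
class Generator_2: ...
class Discriminator_3: ...

generator_name = 'Generator_2'

# Retrieve the class object from its name, as the notebook does;
# globals() is a fallback in case this snippet is not run as '__main__'
module = sys.modules['__main__']
Generator_ = getattr(module, generator_name, None) or globals()[generator_name]

print(Generator_ is Generator_2)   # True
```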

Basic test - Just to be sure it (could) work... ;-)¶

In [8]:
generator     = Generator_(     latent_dim=latent_dim, data_shape=data_shape )
discriminator = Discriminator_( latent_dim=latent_dim, data_shape=data_shape )

print('\nFew tests :\n')
z = torch.randn(batch_size, latent_dim)
print('z size        : ',z.size())

fake_img = generator.forward(z)
print('fake_img      : ', fake_img.size())

p = discriminator.forward(fake_img)
print('pred fake     : ', p.size())

print('batch_data    : ',batch_data.size())

p = discriminator.forward(batch_data)
print('pred real     : ', p.size())

print('\nShow fake images :')
nimg = fake_img.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')
init generator 2         :  128  to  (28, 28, 1)
init discriminator 3     :  (28, 28, 1)  to sigmoid

Few tests :

z size        :  torch.Size([64, 128])
fake_img      :  torch.Size([64, 28, 28, 1])
pred fake     :  torch.Size([64, 1])
batch_data    :  torch.Size([64, 28, 28, 1])
pred real     :  torch.Size([64, 1])

Show fake images :
Saved: ./run/PLSHEEP3_2/figs/01-Sheeps
In [9]:
print('Fake images : ', fake_img.size())
print('Batch size  : ', batch_data.size())
e = torch.distributions.uniform.Uniform(0, 1).sample([batch_size, 1])
e = e[:, None, None]                 # broadcastable shape: (batch_size, 1, 1, 1)
i = fake_img * e + (1 - e) * batch_data

print('\ninterpolate images :')
nimg = i.detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')
Fake images :  torch.Size([64, 28, 28, 1])
Batch size  :  torch.Size([64, 28, 28, 1])

interpolate images :
Saved: ./run/PLSHEEP3_2/figs/01-Sheeps
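
This random interpolation between real and fake batches is the key ingredient of the WGAN-GP gradient penalty: the critic's gradient norm at the interpolated points is pushed toward 1. The real computation lives in ./modules/WGANGP.py; a minimal sketch with a toy critic (the linear critic is a placeholder) could look like this:

```python
import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """WGAN-GP penalty: lambda_gp * mean((||grad critic(x_hat)|| - 1)^2),
    where x_hat is a random interpolation between real and fake samples."""
    batch_size = real.size(0)
    e = torch.rand(batch_size, 1, 1, 1)          # one epsilon per sample
    x_hat = (e * real + (1 - e) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

# Toy critic: a single linear map over the flattened image (placeholder)
critic = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 1))
real = torch.rand(8, 28, 28, 1)
fake = torch.rand(8, 28, 28, 1)
gp = gradient_penalty(critic, real, fake)
print(gp.item() >= 0.0)   # True: the penalty is a non-negative scalar
```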

Step 4 - GAN model¶

To simplify our code, the GAN class is defined separately in the module ./modules/GAN.py
Passing the generator and discriminator class names as parameters keeps the code modular and lets us use the PyTorch Lightning checkpoints.

In [10]:
gan = GAN_( data_shape          = data_shape,
            lr                  = lr,
            b1                  = b1,
            b2                  = b2,
            lambda_gp           = lambda_gp,
            batch_size          = batch_size, 
            latent_dim          = latent_dim, 
            generator_name      = generator_name, 
            discriminator_name  = discriminator_name)
---- GAN initialization --------------------------------------------
Hyperarameters are :
data_shape               : (28, 28, 1)
latent_dim               : 128
lr                       : 0.0001
b1                       : 0.5
b2                       : 0.999
batch_size               : 64
lambda_gp                : 10
generator_name           : Generator_2
discriminator_name       : Discriminator_3
Submodels :
init generator 2         :  128  to  (28, 28, 1)
init discriminator 3     :  (28, 28, 1)  to sigmoid

Step 5 - Train it !¶

Instantiate Callbacks, Logger & co.¶

More about :

  • Checkpoints
  • ModelCheckpoint
In [11]:
# ---- for tensorboard logs
#
logger       = TensorBoardLogger(       save_dir       = f'{run_dir}',
                                        name           = 'tb_logs'  )

log_dir = os.path.abspath(f'{run_dir}/tb_logs')
print('To access the logs with tensorboard, use this command line :')
print(f'tensorboard --logdir {log_dir}')

# ---- To save checkpoints
#
callback_checkpoints = ModelCheckpoint( dirpath        = f'{run_dir}/models', 
                                        filename       = 'bestModel', 
                                        save_top_k     = 1, 
                                        save_last      = True,
                                        every_n_epochs = 1, 
                                        monitor        = "g_loss")
To access the logs with tensorboard, use this command line :
tensorboard --logdir /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/fidle/DCGAN.Lightning/run/PLSHEEP3_2/tb_logs

Train it¶

In [12]:
trainer = Trainer(
    accelerator        = "auto",
    max_epochs         = epochs,
    callbacks          = [callback_checkpoints],
    log_every_n_steps  = batch_size,
    logger             = logger
)

trainer.fit(gan, dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
DataModule Setup :
Original dataset shape :  (126121, 784)
Rescaled dataset shape :  (126121, 784)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Final dataset shape    :  torch.Size([126121, 28, 28, 1])
Dataset loaded and ready.
  | Name          | Type            | Params | Mode  | In sizes | Out sizes     
--------------------------------------------------------------------------------------
0 | generator     | Generator_2     | 780 K  | train | [2, 128] | [2, 28, 28, 1]
1 | discriminator | Discriminator_3 | 401 K  | train | ?        | ?             
--------------------------------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.728     Total estimated model params size (MB)
35        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.
`Trainer.fit` stopped: `max_epochs=30` reached.

Step 6 - Reload our best model¶


In [13]:
gan = GAN.load_from_checkpoint(f'{run_dir}/models/bestModel.ckpt')
---- GAN initialization --------------------------------------------
Hyperarameters are :
data_shape               : (28, 28, 1)
latent_dim               : 128
lr                       : 0.0001
b1                       : 0.5
b2                       : 0.999
batch_size               : 64
generator_name           : Generator_2
discriminator_name       : Discriminator_3
lambda_gp                : 10
Submodels :
init generator 2         :  128  to  (28, 28, 1)
init discriminator 3     :  (28, 28, 1)  to sigmoid
In [14]:
nb_images = 96

z = torch.randn(nb_images, latent_dim)
print('z size        : ',z.size())

if torch.cuda.is_available(): z=z.cuda()

fake_img = gan.generator.forward(z)
print('fake_img      : ', fake_img.size())

nimg = fake_img.cpu().detach().numpy()
fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(nb_images), columns=12, x_size=1, y_size=1, 
                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')
z size        :  torch.Size([96, 128])
fake_img      :  torch.Size([96, 28, 28, 1])
Saved: ./run/PLSHEEP3_2/figs/01-Sheeps
In [15]:
fidle.end()

End time : 22/12/24 22:10:18
Duration : 00:28:22 197ms
This notebook ends here :-)
https://fidle.cnrs.fr

