
[LWINE1] - Wine quality prediction with a Dense Network (DNN)¶

Another regression example: predicting wine quality with PyTorch Lightning

Objectives :¶

  • Predict the quality of wines, based on their analysis
  • Understand the principle and architecture of regression with a dense neural network, including saving and restoring the trained model.

The Wine Quality datasets consist of analyses of a large number of wines, each with an associated quality score (between 0 and 10).
This dataset is provided by:
Paulo Cortez, University of Minho, Guimarães, Portugal, http://www3.dsi.uminho.pt/pcortez
A. Cerdeira, F. Almeida, T. Matos and J. Reis, Viticulture Commission of the Vinho Verde Region (CVRVV), Porto, Portugal, @2009
This dataset can be retrieved from the University of California Irvine (UCI).

Due to privacy and logistic issues, only physicochemical and sensory variables are available
There is no data about grape types, wine brand, wine selling price, etc.

  • fixed acidity
  • volatile acidity
  • citric acid
  • residual sugar
  • chlorides
  • free sulfur dioxide
  • total sulfur dioxide
  • density
  • pH
  • sulphates
  • alcohol
  • quality (score between 0 and 10)

What we're going to do :¶

  • (Retrieve data)
  • (Preparing the data)
  • (Build a model)
  • Train and save the model
  • Restore saved model
  • Evaluate the model
  • Make some predictions

Step 1 - Import and init¶

In [1]:
# Import some packages
import os
import sys
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import lightning.pytorch as pl
import torch.nn.functional as F
import torchvision.transforms as T


from importlib import reload
from IPython.display import Markdown
from torch.utils.data import Dataset, DataLoader, random_split
from modules.progressbar import CustomTrainProgressBar
from modules.data_load import WineQualityDataset, Normalize, ToTensor
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from torchmetrics.functional.regression import mean_absolute_error, mean_squared_error

import fidle

# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('LWINE1')


FIDLE - Environment initialization

Version              : 2.3.2
Run id               : LWINE1
Run dir              : ./run/LWINE1
Datasets dir         : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 22/12/24 21:21:23
Hostname             : r3i6n3 (Linux)
Tensorflow log level : Info + Warning + Error  (=0)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/LWINE1/figs (True)
numpy                : 2.1.2
sklearn              : 1.5.2
yaml                 : 6.0.2
matplotlib           : 3.9.2
pandas               : 2.2.3
torch                : 2.5.0
torchvision          : 0.20.0a0+afc54f7
lightning            : 2.4.0

Verbosity during training :

  • 0 = silent
  • 1 = progress bar
  • 2 = one line per epoch
In [2]:
fit_verbosity = 1
dataset_name  = 'winequality-red.csv'

Override parameters (batch mode) - Just forget this cell

In [3]:
fidle.override('fit_verbosity', 'dataset_name')
** Overrided parameters : **
fit_verbosity        : 2

Step 2 - Retrieve data¶

In [4]:
csv_file_path=f'{datasets_dir}/WineQuality/origine/{dataset_name}'
datasets=WineQualityDataset(csv_file_path)

display(datasets.data.head(5).style.format("{0:.2f}"))
print('Missing Data : ',datasets.data.isna().sum().sum(), '  Shape is : ', datasets.data.shape)
  fixed acidity volatile acidity citric acid residual sugar chlorides free sulfur dioxide total sulfur dioxide density pH sulphates alcohol quality
0 7.40 0.70 0.00 1.90 0.08 11.00 34.00 1.00 3.51 0.56 9.40 5.00
1 7.80 0.88 0.00 2.60 0.10 25.00 67.00 1.00 3.20 0.68 9.80 5.00
2 7.80 0.76 0.04 2.30 0.09 15.00 54.00 1.00 3.26 0.65 9.80 5.00
3 11.20 0.28 0.56 1.90 0.07 17.00 60.00 1.00 3.16 0.58 9.80 6.00
4 7.40 0.70 0.00 1.90 0.08 11.00 34.00 1.00 3.51 0.56 9.40 5.00
Missing Data :  0   Shape is :  (1599, 12)

Step 3 - Preparing the data¶

3.1 - Data normalization¶

Note :

  • All input features must be normalized.
  • To do this, we subtract the mean and divide by the standard deviation for each input feature.
  • Then we convert the numpy feature and target (quality) arrays to torch tensors.
In [5]:
transforms=T.Compose([Normalize(csv_file_path), ToTensor()])

dataset=WineQualityDataset(csv_file_path,transform=transforms)
In [6]:
display(Markdown("before normalization :"))
display(datasets[:]["features"])

print()

display(Markdown("After normalization :"))
display(dataset[:]["features"])

before normalization :

array([[ 7.4  ,  0.7  ,  0.   , ...,  3.51 ,  0.56 ,  9.4  ],
       [ 7.8  ,  0.88 ,  0.   , ...,  3.2  ,  0.68 ,  9.8  ],
       [ 7.8  ,  0.76 ,  0.04 , ...,  3.26 ,  0.65 ,  9.8  ],
       ...,
       [ 6.3  ,  0.51 ,  0.13 , ...,  3.42 ,  0.75 , 11.   ],
       [ 5.9  ,  0.645,  0.12 , ...,  3.57 ,  0.71 , 10.2  ],
       [ 6.   ,  0.31 ,  0.47 , ...,  3.39 ,  0.66 , 11.   ]],
      dtype=float32)

After normalization :

tensor([[-0.5282,  0.9616, -1.3910,  ...,  1.2882, -0.5790, -0.9599],
        [-0.2985,  1.9668, -1.3910,  ..., -0.7197,  0.1289, -0.5846],
        [-0.2985,  1.2967, -1.1857,  ..., -0.3311, -0.0481, -0.5846],
        ...,
        [-1.1600, -0.0995, -0.7237,  ...,  0.7053,  0.5419,  0.5415],
        [-1.3897,  0.6544, -0.7750,  ...,  1.6769,  0.3059, -0.2092],
        [-1.3323, -1.2165,  1.0217,  ...,  0.5110,  0.0109,  0.5415]])
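
The normalization described in the note can be sketched with plain numpy. This is only an illustration: the actual `Normalize` class lives in `modules.data_load` and is not shown here, so details such as the exact standard deviation estimator are assumptions. The toy matrix reuses the first two columns of the first four rows displayed above, so its statistics differ from those of the full dataset.

```python
import numpy as np

# Toy feature matrix: 4 samples x 2 features
# (fixed acidity, volatile acidity of the first four rows shown above)
x = np.array([[ 7.4, 0.70],
              [ 7.8, 0.88],
              [ 7.8, 0.76],
              [11.2, 0.28]], dtype=np.float32)

mean = x.mean(axis=0)   # per-feature mean
std  = x.std(axis=0)    # per-feature standard deviation (assumption: population std, ddof=0)

# z-score: subtract the mean, divide by the standard deviation
x_norm = (x - mean) / std

print(x_norm.mean(axis=0))   # close to 0 for each feature
print(x_norm.std(axis=0))    # close to 1 for each feature
```

After this transform every feature is centered and has unit spread, which keeps features with large raw values (such as total sulfur dioxide) from dominating the first layer.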

3.2 - Split data¶

We will use 80% of the data for training and the remaining 20% for validation.
x holds the analysis features and y the target (quality).

In [7]:
# ---- Split => train, test
#
data_train_len = int(len(dataset)*0.8)            # get 80 %
data_test_len  = len(dataset) -data_train_len     # test = all - train

# ---- Split => x,y with random_split
#
data_train_subset, data_test_subset=random_split(dataset, [data_train_len, data_test_len])       

                                                
x_train = data_train_subset[:]["features"]
y_train = data_train_subset[:]["quality" ]

x_test  = data_test_subset [:]["features"]
y_test  = data_test_subset [:]["quality" ]


print('Original data shape was : ',dataset.data.shape)
print('x_train : ',x_train.shape, 'y_train : ',y_train.shape)
print('x_test  : ',x_test.shape,  'y_test  : ',y_test.shape)
Original data shape was :  (1599, 12)
x_train :  torch.Size([1279, 11]) y_train :  torch.Size([1279, 1])
x_test  :  torch.Size([320, 11]) y_test  :  torch.Size([320, 1])
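
The split above is random, so the train and test subsets change from one run to the next. If a reproducible split is wanted, one option (an assumption, not something the notebook does) is to pass a seeded `torch.Generator` to `random_split`. The sketch below uses a stand-in `TensorDataset` with 1599 rows and an arbitrary seed of 42.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset with the same number of rows as winequality-red.csv
toy = TensorDataset(torch.arange(1599))

n_train = int(len(toy) * 0.8)      # 80 % for training
n_test  = len(toy) - n_train       # the rest for testing

def split(seed):
    # A seeded generator makes the permutation deterministic
    g = torch.Generator().manual_seed(seed)
    return random_split(toy, [n_train, n_test], generator=g)

train_a, test_a = split(42)
train_b, test_b = split(42)

print(len(train_a), len(test_a))             # sizes of the two subsets
print(train_a.indices == train_b.indices)    # same seed, same split
```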

3.3 - Use a DataLoader for training¶

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in minibatches and reshuffle the data at every epoch to reduce overfitting. DataLoader is an iterable that hides this complexity behind a simple API.

In [8]:
# training batch data
train_loader= DataLoader(
  dataset=data_train_subset, 
  shuffle=True, 
  batch_size=20,
  num_workers=2  
)


# test batch data
test_loader= DataLoader(
  dataset=data_test_subset, 
  shuffle=False, 
  batch_size=20,
  num_workers=2
)
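
To see what a loader yields, here is a minimal sketch with a hypothetical `ToyDictDataset` that mimics the dict-per-sample layout (`"features"`, `"quality"`) used in this notebook. The real `WineQualityDataset` is defined in `modules.data_load` and may differ. The default collate function stacks dict entries key by key, so each batch is itself a dict of tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDictDataset(Dataset):
    """Hypothetical stand-in: returns one dict per sample."""
    def __init__(self, n=50):
        self.features = torch.randn(n, 11)                    # 11 random input features
        self.quality  = torch.randint(3, 9, (n, 1)).float()   # random integer scores

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return {"features": self.features[idx], "quality": self.quality[idx]}

# num_workers is left at its default (0) to keep the sketch simple
loader = DataLoader(ToyDictDataset(), batch_size=20, shuffle=True)

# Collect the shape of every batch produced during one epoch
shapes = [(tuple(b["features"].shape), tuple(b["quality"].shape)) for b in loader]
for f_shape, q_shape in shapes:
    print(f_shape, q_shape)
```

With 50 samples and a batch size of 20, the last batch is smaller than the others; this is why the training code should not assume a fixed batch size.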

Step 4 - Build a model¶

More information about:

  • Optimizer
  • Activation
  • Loss
  • Metrics
In [9]:
class LitRegression(pl.LightningModule):
    
    def __init__(self,in_features=11):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, 128),                               # hidden layer 1
            nn.ReLU(),                                                 # activation function 
            nn.Linear(128, 128),                                       # hidden layer 2
            nn.ReLU(),                                                 # activation function
            nn.Linear(128, 1))                                         # output layer   
    
    def forward(self, x):                                              # forward pass
        x = self.model(x)
        return x        

   
    # optimizer
    def configure_optimizers(self):                              
        optimizer = torch.optim.RMSprop(self.parameters(),lr=1e-4)
        return optimizer 
        
    
    def training_step(self, batch, batch_idx):
        # defines the train loop.
        x_features, y_target = batch["features"],batch["quality"]
        
        # forward pass
        y_pred = self.model(x_features)

        # loss function MSE
        loss   = F.mse_loss(y_pred, y_target)                           

        # metrics mae
        mae    = mean_absolute_error(y_pred,y_target) 

        # metrics mse
        mse    = mean_squared_error(y_pred,y_target)                    
        
        metrics= {"train_loss": loss, 
                   "train_mae" : mae, 
                   "train_mse" : mse
                  }
        
        # log metrics, averaged over each epoch (on_step=False, on_epoch=True)
        self.log_dict(metrics, 
                      on_step  = False,                     
                      on_epoch = True, 
                      logger   = True,
                      prog_bar = True,     
                     )
        return loss      

        
    def validation_step(self, batch, batch_idx):
        
        # defines the val loop.
        x_features, y_target = batch["features"],batch["quality"]

        # forward pass
        y_pred = self.model(x_features)

        # loss function MSE
        loss   = F.mse_loss(y_pred, y_target)                             

        # metrics
        mae    = mean_absolute_error(y_pred,y_target)

        # metrics
        mse    = mean_squared_error(y_pred,y_target)                          

        
        metrics= {"val_loss": loss, 
                   "val_mae" : mae, 
                   "val_mse" : mse
                  }
       
        # log metrics, averaged over each epoch (on_step=False, on_epoch=True)
        self.log_dict(metrics,                               
                      on_step  = False,                     
                      on_epoch = True, 
                      logger   = True,
                      prog_bar = True,     
                     )

        return metrics
            
   

Step 5 - Train the model¶

5.1 - Get it¶

In [10]:
reg=LitRegression(in_features=11)
print(reg) 
LitRegression(
  (model): Sequential(
    (0): Linear(in_features=11, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)
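
As a quick sanity check on this architecture, the sketch below rebuilds the same stack as a bare `nn.Sequential` (without the Lightning wrapper), counts its parameters and pushes a random batch through it. Each `Linear` layer contributes `in × out` weights plus `out` biases; the total is consistent with the 18.2 K reported by Lightning's model summary in section 5.3.

```python
import torch
import torch.nn as nn

# Same layer stack as LitRegression.model, outside of Lightning
model = nn.Sequential(
    nn.Linear(11, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 1),
)

# Parameter count reported by PyTorch
n_params = sum(p.numel() for p in model.parameters())

# Same count by hand: weights + biases for each Linear layer
by_hand = (11 * 128 + 128) + (128 * 128 + 128) + (128 * 1 + 1)

# Forward pass on a random batch of 20 samples with 11 features
batch = torch.randn(20, 11)
with torch.no_grad():
    out = model(batch)

print(n_params, by_hand)     # both 18177
print(tuple(out.shape))      # one prediction per sample
```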

5.2 - Add callback¶

In [11]:
os.makedirs('./run/models', exist_ok=True)
save_dir = "./run/models/"
filename ='best-model-{epoch}-{val_loss:.2f}'

savemodel_callback = pl.callbacks.ModelCheckpoint(dirpath=save_dir, 
                                                  filename=filename,
                                                  save_top_k=1, 
                                                  verbose=False, 
                                                  monitor="val_loss"
                                                 )

5.3 - Train it¶

In [12]:
# loggers data
os.makedirs(f'{run_dir}/logs',   mode=0o750, exist_ok=True)
logger= TensorBoardLogger(save_dir=f'{run_dir}/logs',name="reg_logs")
In [13]:
# train model
trainer = pl.Trainer(accelerator='auto',
                     max_epochs=100,
                     logger=logger,
                     num_sanity_val_steps=0,
                     callbacks=[savemodel_callback,CustomTrainProgressBar()])

trainer.fit(model=reg, train_dataloaders=train_loader, val_dataloaders=test_loader)
0it [00:00, ?it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/lustre/fshomisc/sup/hpe/pub/miniforge/24.9.0/envs/pytorch-gpu-2.5.0+py3.12.7/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/fidle/Wine.Lightning/run/models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 18.2 K | train
---------------------------------------------
18.2 K    Trainable params
0         Non-trainable params
18.2 K    Total params
0.073     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.
`Trainer.fit` stopped: `max_epochs=100` reached.

Step 6 - Evaluate¶

6.1 - Model evaluation¶

MAE = Mean Absolute Error (between the labels and predictions)
An MAE of 0.5 means that predictions are off by half a quality point on average.

In [14]:
score=trainer.validate(model=reg, dataloaders=test_loader, verbose=False)

print('x_test / loss      : {:5.4f}'.format(score[0]['val_loss']))
print('x_test / mae       : {:5.4f}'.format(score[0]['val_mae']))
print('x_test / mse       : {:5.4f}'.format(score[0]['val_mse']))
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
x_test / loss      : 0.4484
x_test / mae       : 0.5249
x_test / mse       : 0.4484
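
The loss and mse lines are identical because the loss used for training is itself the mean squared error. The sketch below checks this on three illustrative values with plain torch; it assumes that `mean_squared_error` from torchmetrics is used with its default settings, as in the model above, and does not call torchmetrics itself.

```python
import torch
import torch.nn.functional as F

# Three illustrative predictions and targets, shaped (N, 1) like the model output
y_pred = torch.tensor([[5.93], [5.46], [5.30]])
y_true = torch.tensor([[6.0], [6.0], [5.0]])

# Loss as computed in training_step / validation_step
loss = F.mse_loss(y_pred, y_true)

# Mean squared error written out explicitly
mse = ((y_pred - y_true) ** 2).mean()

print(loss.item(), mse.item())
print(torch.isclose(loss, mse).item())
```

Logging both is therefore redundant for this model; it becomes useful only if the training loss is later changed to something other than MSE.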

6.2 - Training history¶

To access logs with TensorBoard:

  • Under Docker, from a terminal launched via the jupyterlab launcher, use the following command:
    tensorboard --logdir <path-to-logs> --host 0.0.0.0
  • If you're not using Docker, from a terminal :
    tensorboard --logdir <path-to-logs>

Note: Only one TensorBoard instance can be used at a time.

Step 7 - Restore a model :¶

7.1 - Reload model¶

In [15]:
# Load the model from a checkpoint
loaded_model = LitRegression.load_from_checkpoint(savemodel_callback.best_model_path)
print("Loaded:")
print(loaded_model)
Loaded:
LitRegression(
  (model): Sequential(
    (0): Linear(in_features=11, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)

7.2 - Evaluate it :¶

In [16]:
score=trainer.validate(model=loaded_model, dataloaders=test_loader, verbose=False)

print('x_test / loss      : {:5.4f}'.format(score[0]['val_loss']))
print('x_test / mae       : {:5.4f}'.format(score[0]['val_mae']))
print('x_test / mse       : {:5.4f}'.format(score[0]['val_mse']))
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
x_test / loss      : 0.4396
x_test / mae       : 0.5176
x_test / mse       : 0.4396

The restored checkpoint scores slightly better than the final-epoch model (val_loss 0.4396 vs 0.4484), since ModelCheckpoint kept the epoch with the lowest val_loss rather than the last one.

7.3 - Make a prediction¶

In [17]:
# ---- Pick n entries from our test set
n = 200
ii = np.random.randint(0,len(x_test),n)    # low bound is 0 so the first test sample can be drawn too
x_sample = x_test[ii]
y_sample = y_test[ii]
In [18]:
# ---- Make predictions :

# Set the model to evaluation mode.
loaded_model.eval() 

# Perform inference with the loaded model (no gradients are needed here)
with torch.no_grad():
    y_pred = loaded_model( x_sample )
In [19]:
# ---- Show it
print('Wine    Prediction   Real   Delta')
for i in range(n):
    pred   = y_pred[i][0].item()
    real   = y_sample[i][0].item()
    delta  = real-pred
    print(f'{i:03d}        {pred:.2f}       {real}      {delta:+.2f} ')
Wine    Prediction   Real   Delta
000        5.93       6.0      +0.07 
001        5.46       6.0      +0.54 
002        5.30       5.0      -0.30 
003        5.23       6.0      +0.77 
004        5.73       5.0      -0.73 
005        5.64       5.0      -0.64 
006        5.18       6.0      +0.82 
007        4.57       3.0      -1.57 
008        5.02       4.0      -1.02 
009        5.13       6.0      +0.87 
010        5.72       5.0      -0.72 
011        4.97       5.0      +0.03 
012        5.77       6.0      +0.23 
013        6.37       6.0      -0.37 
014        4.65       5.0      +0.35 
015        6.83       7.0      +0.17 
016        5.08       4.0      -1.08 
017        5.07       7.0      +1.93 
018        6.17       4.0      -2.17 
019        6.03       7.0      +0.97 
020        6.23       7.0      +0.77 
021        5.51       6.0      +0.49 
022        5.63       6.0      +0.37 
023        5.94       6.0      +0.06 
024        4.89       5.0      +0.11 
025        6.30       7.0      +0.70 
026        5.25       5.0      -0.25 
027        4.90       5.0      +0.10 
028        5.22       5.0      -0.22 
029        6.10       6.0      -0.10 
030        6.58       7.0      +0.42 
031        5.82       6.0      +0.18 
032        5.42       6.0      +0.58 
033        6.58       7.0      +0.42 
034        5.38       5.0      -0.38 
035        5.03       5.0      -0.03 
036        6.46       6.0      -0.46 
037        5.52       5.0      -0.52 
038        6.21       8.0      +1.79 
039        5.57       7.0      +1.43 
040        5.71       6.0      +0.29 
041        5.80       5.0      -0.80 
042        5.20       6.0      +0.80 
043        5.11       5.0      -0.11 
044        5.09       5.0      -0.09 
045        5.94       5.0      -0.94 
046        6.50       6.0      -0.50 
047        4.73       4.0      -0.73 
048        6.53       7.0      +0.47 
049        5.64       5.0      -0.64 
050        6.53       7.0      +0.47 
051        5.22       6.0      +0.78 
052        6.50       6.0      -0.50 
053        6.02       6.0      -0.02 
054        5.29       5.0      -0.29 
055        5.92       6.0      +0.08 
056        5.30       6.0      +0.70 
057        5.45       6.0      +0.55 
058        5.64       6.0      +0.36 
059        5.11       5.0      -0.11 
060        4.97       3.0      -1.97 
061        6.54       6.0      -0.54 
062        5.33       6.0      +0.67 
063        5.20       5.0      -0.20 
064        6.92       7.0      +0.08 
065        5.09       5.0      -0.09 
066        6.55       7.0      +0.45 
067        5.99       6.0      +0.01 
068        5.26       5.0      -0.26 
069        6.79       6.0      -0.79 
070        4.97       3.0      -1.97 
071        6.79       6.0      -0.79 
072        5.98       7.0      +1.02 
073        5.36       5.0      -0.36 
074        5.06       5.0      -0.06 
075        6.02       6.0      -0.02 
076        5.02       6.0      +0.98 
077        4.82       5.0      +0.18 
078        6.69       7.0      +0.31 
079        5.03       5.0      -0.03 
080        6.03       6.0      -0.03 
081        5.25       5.0      -0.25 
082        6.23       7.0      +0.77 
083        4.98       5.0      +0.02 
084        5.34       5.0      -0.34 
085        5.74       6.0      +0.26 
086        4.65       5.0      +0.35 
087        5.51       6.0      +0.49 
088        5.06       5.0      -0.06 
089        5.34       5.0      -0.34 
090        5.19       5.0      -0.19 
091        6.03       6.0      -0.03 
092        5.25       5.0      -0.25 
093        4.95       6.0      +1.05 
094        6.44       6.0      -0.44 
095        5.41       5.0      -0.41 
096        5.17       5.0      -0.17 
097        5.73       5.0      -0.73 
098        5.59       5.0      -0.59 
099        6.82       6.0      -0.82 
100        6.50       6.0      -0.50 
101        4.71       5.0      +0.29 
102        5.40       6.0      +0.60 
103        5.86       5.0      -0.86 
104        6.93       7.0      +0.07 
105        5.98       7.0      +1.02 
106        5.13       5.0      -0.13 
107        5.65       5.0      -0.65 
108        5.14       5.0      -0.14 
109        5.34       5.0      -0.34 
110        4.99       5.0      +0.01 
111        5.38       5.0      -0.38 
112        6.57       7.0      +0.43 
113        5.20       5.0      -0.20 
114        5.38       5.0      -0.38 
115        5.84       6.0      +0.16 
116        5.25       5.0      -0.25 
117        4.87       5.0      +0.13 
118        6.69       7.0      +0.31 
119        6.21       6.0      -0.21 
120        5.61       5.0      -0.61 
121        5.58       6.0      +0.42 
122        6.58       7.0      +0.42 
123        6.32       8.0      +1.68 
124        5.59       5.0      -0.59 
125        6.07       6.0      -0.07 
126        5.35       7.0      +1.65 
127        4.71       5.0      +0.29 
128        6.61       6.0      -0.61 
129        6.37       7.0      +0.63 
130        6.07       6.0      -0.07 
131        5.46       5.0      -0.46 
132        5.13       5.0      -0.13 
133        6.03       7.0      +0.97 
134        5.22       5.0      -0.22 
135        5.34       5.0      -0.34 
136        5.34       5.0      -0.34 
137        5.25       5.0      -0.25 
138        5.52       6.0      +0.48 
139        5.08       5.0      -0.08 
140        5.36       7.0      +1.64 
141        5.08       5.0      -0.08 
142        6.03       7.0      +0.97 
143        5.14       5.0      -0.14 
144        5.38       5.0      -0.38 
145        6.04       6.0      -0.04 
146        5.09       5.0      -0.09 
147        5.44       5.0      -0.44 
148        5.24       5.0      -0.24 
149        5.46       6.0      +0.54 
150        5.26       4.0      -1.26 
151        5.43       5.0      -0.43 
152        5.38       6.0      +0.62 
153        6.50       6.0      -0.50 
154        5.66       4.0      -1.66 
155        5.92       7.0      +1.08 
156        6.56       7.0      +0.44 
157        4.95       5.0      +0.05 
158        6.23       7.0      +0.77 
159        4.95       6.0      +1.05 
160        4.22       5.0      +0.78 
161        5.24       5.0      -0.24 
162        4.56       5.0      +0.44 
163        5.38       6.0      +0.62 
164        6.93       7.0      +0.07 
165        5.09       4.0      -1.09 
166        5.22       6.0      +0.78 
167        4.83       4.0      -0.83 
168        5.55       5.0      -0.55 
169        5.68       6.0      +0.32 
170        5.11       5.0      -0.11 
171        6.44       6.0      -0.44 
172        5.98       7.0      +1.02 
173        6.34       7.0      +0.66 
174        5.22       5.0      -0.22 
175        5.81       6.0      +0.19 
176        4.45       4.0      -0.45 
177        5.59       5.0      -0.59 
178        5.87       5.0      -0.87 
179        5.02       6.0      +0.98 
180        5.20       6.0      +0.80 
181        6.35       7.0      +0.65 
182        5.57       6.0      +0.43 
183        4.97       3.0      -1.97 
184        5.13       5.0      -0.13 
185        6.29       7.0      +0.71 
186        5.34       5.0      -0.34 
187        5.64       6.0      +0.36 
188        6.56       7.0      +0.44 
189        5.49       5.0      -0.49 
190        4.87       4.0      -0.87 
191        4.82       5.0      +0.18 
192        6.50       6.0      -0.50 
193        6.56       7.0      +0.44 
194        6.17       6.0      -0.17 
195        5.55       5.0      -0.55 
196        5.76       6.0      +0.24 
197        5.86       5.0      -0.86 
198        5.74       6.0      +0.26 
199        6.58       7.0      +0.42 
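
As a sanity check on the table above, the MAE and MSE can be recomputed by hand on a few rows. The sketch below uses the first five (prediction, real) pairs, rows 000 to 004, in plain Python. Five rows are far too few to match the scores of section 7.2; the point is only to show how the two metrics are built from the Delta column.

```python
# (prediction, real) pairs copied from rows 000-004 of the table above
rows = [(5.93, 6.0), (5.46, 6.0), (5.30, 5.0), (5.23, 6.0), (5.73, 5.0)]

# Delta = real - prediction, as in the table
deltas = [real - pred for pred, real in rows]

mae = sum(abs(d) for d in deltas) / len(deltas)   # mean absolute error
mse = sum(d * d for d in deltas) / len(deltas)    # mean squared error

print(f"MAE = {mae:.3f}, MSE = {mse:.3f}")        # MAE = 0.482, MSE = 0.302
```

Because errors are squared before averaging, the MSE penalizes the two largest deltas (+0.77 and -0.73) much more heavily than the small one (+0.07).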
In [20]:
fidle.end()

End time : 22/12/24 21:22:40
Duration : 00:01:17 141ms
This notebook ends here :-)
https://fidle.cnrs.fr

