[K3WINE1] - Wine quality prediction with a Dense Network (DNN)
Another example of regression: predicting the quality of wines, using Keras 3 and PyTorch.
Objectives :
- Predict the quality of wines, based on their analysis
- Understand the principle and the architecture of a regression with a dense neural network, with saving and restoring of the trained model.
The Wine Quality datasets are made up of analyses of a large number of wines, with an associated quality score (between 0 and 10).
This dataset is provided by :
Paulo Cortez, University of Minho, Guimarães, Portugal, http://www3.dsi.uminho.pt/pcortez
A. Cerdeira, F. Almeida, T. Matos and J. Reis, Viticulture Commission of the Vinho Verde Region (CVRVV), Porto, Portugal, ©2009
It can be retrieved from the University of California, Irvine (UCI) Machine Learning Repository.
Due to privacy and logistics issues, only physicochemical and sensory variables are available :
there is no data about grape types, wine brand, wine selling price, etc.
Each analysis contains the following attributes :
- fixed acidity
- volatile acidity
- citric acid
- residual sugar
- chlorides
- free sulfur dioxide
- total sulfur dioxide
- density
- pH
- sulphates
- alcohol
- quality (score between 0 and 10)
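As mentioned above, the file can also be fetched directly from the UCI repository if the Fidle datasets are not available locally. A minimal sketch, assuming the classic UCI download URL (given here as an assumption, it may change over time):

import pandas as pd

# Hypothetical direct download from UCI (URL assumed, not part of this notebook)
url  = 'https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv'
data = pd.read_csv(url, sep=';')   # the file is semicolon-separated
print(data.shape)                  # expected : (1599, 12)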
What we're going to do :
- (Retrieve data)
- (Prepare the data)
- (Build a model)
- Train and save the model
- Restore saved model
- Evaluate the model
- Make some predictions
Step 1 - Import and init
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
import numpy as np
import pandas as pd
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3WINE1')
FIDLE - Environment initialization
Version              : 2.3.0
Run id               : K3WINE1
Run dir              : ./run/K3WINE1
Datasets dir         : /gpfswork/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 03/03/24 21:04:35
Hostname             : r6i1n1 (Linux)
Tensorflow log level : Warning + Error (=1)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3WINE1/figs (True)
keras                : 3.0.4
numpy                : 1.24.4
sklearn              : 1.3.2
yaml                 : 6.0.1
matplotlib           : 3.8.2
pandas               : 2.1.3
torch                : 2.1.1
Verbosity during training :
- 0 = silent
- 1 = progress bar
- 2 = one line per epoch
fit_verbosity = 1
dataset_name = 'winequality-red.csv'
Override parameters (batch mode) - Just forget this cell
fidle.override('fit_verbosity', 'dataset_name')
** Overrided parameters : **
fit_verbosity : 2
Step 2 - Retrieve data
data = pd.read_csv(f'{datasets_dir}/WineQuality/origine/{dataset_name}', header=0,sep=';')
display(data.head(5).style.format("{0:.2f}"))
print('Missing Data : ',data.isna().sum().sum(), ' Shape is : ', data.shape)
|   | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7.40 | 0.70 | 0.00 | 1.90 | 0.08 | 11.00 | 34.00 | 1.00 | 3.51 | 0.56 | 9.40 | 5.00 |
| 1 | 7.80 | 0.88 | 0.00 | 2.60 | 0.10 | 25.00 | 67.00 | 1.00 | 3.20 | 0.68 | 9.80 | 5.00 |
| 2 | 7.80 | 0.76 | 0.04 | 2.30 | 0.09 | 15.00 | 54.00 | 1.00 | 3.26 | 0.65 | 9.80 | 5.00 |
| 3 | 11.20 | 0.28 | 0.56 | 1.90 | 0.07 | 17.00 | 60.00 | 1.00 | 3.16 | 0.58 | 9.80 | 6.00 |
| 4 | 7.40 | 0.70 | 0.00 | 1.90 | 0.08 | 11.00 | 34.00 | 1.00 | 3.51 | 0.56 | 9.40 | 5.00 |
Missing Data : 0 Shape is : (1599, 12)
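Before splitting, it can also be interesting to look at how the quality scores are distributed; a small optional sketch (not part of the original notebook):

# ---- Optional : distribution of the target
print(data['quality'].value_counts().sort_index())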
# ---- Split => train, test
#
data = data.sample(frac=1., axis=0) # Shuffle
data_train = data.sample(frac=0.8, axis=0) # get 80 %
data_test = data.drop(data_train.index) # test = all - train
# ---- Split => x,y ('quality' is the target)
#
x_train = data_train.drop('quality', axis=1)
y_train = data_train['quality']
x_test = data_test.drop('quality', axis=1)
y_test = data_test['quality']
print('Original data shape was : ',data.shape)
print('x_train : ',x_train.shape, 'y_train : ',y_train.shape)
print('x_test : ',x_test.shape, 'y_test : ',y_test.shape)
Original data shape was : (1599, 12)
x_train : (1279, 11) y_train : (1279,)
x_test : (320, 11) y_test : (320,)
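Note that data.sample() shuffles the rows differently on every run, so the split (and the figures below) will vary slightly. If a reproducible split is wanted, a seed can be passed; a minimal sketch, with an arbitrary seed value:

# ---- Optional : reproducible shuffle and split (seed value is arbitrary)
data       = data.sample(frac=1.,  axis=0, random_state=42)   # Shuffle
data_train = data.sample(frac=0.8, axis=0, random_state=42)   # get 80 %
data_test  = data.drop(data_train.index)                      # test = all - train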
3.2 - Data normalization
Note :
- All input data, train and test, must be normalized.
- To do this, we subtract the mean and divide by the standard deviation.
- However, the test data must not be used in any way, not even to compute the normalization statistics.
- The mean and the standard deviation are therefore calculated from the train data only.
display(x_train.describe().style.format("{0:.2f}").set_caption("Before normalization :"))
mean = x_train.mean()
std = x_train.std()
x_train = (x_train - mean) / std
x_test = (x_test - mean) / std
display(x_train.describe().style.format("{0:.2f}").set_caption("After normalization :"))
# Convert our DataFrames to numpy arrays
x_train, y_train = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)
Before normalization :

|   | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol |
|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 |
| mean | 8.30 | 0.53 | 0.27 | 2.51 | 0.09 | 15.84 | 46.67 | 1.00 | 3.31 | 0.66 | 10.42 |
| std | 1.72 | 0.18 | 0.19 | 1.34 | 0.05 | 10.32 | 33.33 | 0.00 | 0.15 | 0.17 | 1.07 |
| min | 4.70 | 0.12 | 0.00 | 0.90 | 0.01 | 1.00 | 6.00 | 0.99 | 2.86 | 0.33 | 8.40 |
| 25% | 7.10 | 0.39 | 0.09 | 1.90 | 0.07 | 7.00 | 22.00 | 1.00 | 3.21 | 0.55 | 9.50 |
| 50% | 7.90 | 0.52 | 0.25 | 2.20 | 0.08 | 14.00 | 38.00 | 1.00 | 3.31 | 0.62 | 10.10 |
| 75% | 9.20 | 0.64 | 0.42 | 2.60 | 0.09 | 21.50 | 62.00 | 1.00 | 3.41 | 0.73 | 11.10 |
| max | 15.60 | 1.58 | 0.79 | 15.50 | 0.61 | 68.00 | 289.00 | 1.00 | 4.01 | 1.98 | 14.00 |
After normalization :

|   | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol |
|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 |
| mean | -0.00 | -0.00 | -0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 |
| std | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| min | -2.10 | -2.26 | -1.39 | -1.21 | -1.65 | -1.44 | -1.22 | -3.56 | -2.95 | -1.93 | -1.88 |
| 25% | -0.70 | -0.77 | -0.92 | -0.46 | -0.35 | -0.86 | -0.74 | -0.61 | -0.68 | -0.64 | -0.86 |
| 50% | -0.23 | -0.05 | -0.10 | -0.23 | -0.17 | -0.18 | -0.26 | 0.02 | -0.03 | -0.23 | -0.30 |
| 75% | 0.52 | 0.62 | 0.78 | 0.07 | 0.07 | 0.55 | 0.46 | 0.59 | 0.62 | 0.42 | 0.64 |
| max | 4.25 | 5.83 | 2.69 | 9.72 | 11.57 | 5.06 | 7.27 | 3.71 | 4.51 | 7.76 | 3.34 |
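For reference, the same normalization could be done with scikit-learn (already present in the environment), by fitting a StandardScaler on the training set only. A minimal sketch of this alternative, applied to the raw DataFrames instead of the cell above:

from sklearn.preprocessing import StandardScaler

# ---- Fit the scaler on the train data only, then apply it to both sets
scaler  = StandardScaler().fit(x_train)
x_train = scaler.transform(x_train)
x_test  = scaler.transform(x_test)
# Note : StandardScaler uses the biased std (ddof=0), so values differ marginally
#        from the pandas version above (ddof=1).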
def get_model_v1(shape):
model = keras.models.Sequential()
model.add(keras.layers.Input(shape, name="InputLayer"))
model.add(keras.layers.Dense(64, activation='relu', name='Dense_n1'))
model.add(keras.layers.Dense(64, activation='relu', name='Dense_n2'))
model.add(keras.layers.Dense(1, name='Output'))
model.compile(optimizer = 'rmsprop',
loss = 'mse',
metrics = ['mae', 'mse'] )
return model
model=get_model_v1( (11,) )
model.summary()
Model: "sequential"
| Layer (type) | Output Shape | Param # |
|---|---|---|
| Dense_n1 (Dense) | (None, 64) | 768 |
| Dense_n2 (Dense) | (None, 64) | 4,160 |
| Output (Dense) | (None, 1) | 65 |
Total params: 4,993 (19.50 KB)
Trainable params: 4,993 (19.50 KB)
Non-trainable params: 0 (0.00 B)
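The parameter count can be checked by hand: a Dense layer has (inputs × units) weights plus units biases.

# ---- Quick check of the parameter count (weights + biases per layer)
print(11*64 + 64)       # Dense_n1 :   768
print(64*64 + 64)       # Dense_n2 : 4,160
print(64*1  + 1)        # Output   :    65
print(768 + 4160 + 65)  # Total    : 4,993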
5.2 - Add callback
os.makedirs('./run/models', mode=0o750, exist_ok=True)
save_dir = "./run/models/best_model.keras"
# ---- Callback : keep the model with the best (i.e. lowest) val_mae
savemodel_callback = keras.callbacks.ModelCheckpoint( filepath=save_dir, monitor='val_mae', mode='min', save_best_only=True)
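A callback that is often combined with ModelCheckpoint, although it is not used in this notebook, is EarlyStopping, which stops training once the monitored metric stops improving. A sketch, with an arbitrary patience value:

# ---- Optional : early stopping (not used in this notebook, patience is arbitrary)
earlystop_callback = keras.callbacks.EarlyStopping( monitor='val_mae',
                                                    mode='min',
                                                    patience=10,
                                                    restore_best_weights=True)
# It would simply be appended to the callbacks list passed to model.fit()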
5.3 - Train it
history = model.fit(x_train,
y_train,
epochs = 100,
batch_size = 10,
verbose = fit_verbosity,
validation_data = (x_test, y_test),
callbacks = [savemodel_callback])
Epoch 1/100 128/128 - 1s - 8ms/step - loss: 5.7515 - mae: 1.8792 - mse: 5.7515 - val_loss: 2.1294 - val_mae: 1.1600 - val_mse: 2.1294 Epoch 2/100 128/128 - 1s - 6ms/step - loss: 1.4692 - mae: 0.9380 - mse: 1.4692 - val_loss: 1.2868 - val_mae: 0.8921 - val_mse: 1.2868 Epoch 3/100 128/128 - 1s - 6ms/step - loss: 0.9939 - mae: 0.7782 - mse: 0.9939 - val_loss: 1.1022 - val_mae: 0.8124 - val_mse: 1.1022 Epoch 4/100 128/128 - 1s - 6ms/step - loss: 0.7318 - mae: 0.6702 - mse: 0.7318 - val_loss: 0.8095 - val_mae: 0.6965 - val_mse: 0.8095 Epoch 5/100 128/128 - 1s - 6ms/step - loss: 0.5927 - mae: 0.5962 - mse: 0.5927 - val_loss: 0.7140 - val_mae: 0.6409 - val_mse: 0.7140 Epoch 6/100 128/128 - 1s - 6ms/step - loss: 0.5002 - mae: 0.5493 - mse: 0.5002 - val_loss: 0.6444 - val_mae: 0.6203 - val_mse: 0.6444 Epoch 7/100 128/128 - 1s - 6ms/step - loss: 0.4577 - mae: 0.5202 - mse: 0.4577 - val_loss: 0.6213 - val_mae: 0.6069 - val_mse: 0.6213 Epoch 8/100 128/128 - 1s - 6ms/step - loss: 0.4335 - mae: 0.5094 - mse: 0.4335 - val_loss: 0.5809 - val_mae: 0.5781 - val_mse: 0.5809 Epoch 9/100 128/128 - 1s - 6ms/step - loss: 0.4129 - mae: 0.5016 - mse: 0.4129 - val_loss: 0.6212 - val_mae: 0.5983 - val_mse: 0.6212 Epoch 10/100 128/128 - 1s - 6ms/step - loss: 0.3850 - mae: 0.4786 - mse: 0.3850 - val_loss: 0.5966 - val_mae: 0.5826 - val_mse: 0.5966 Epoch 11/100 128/128 - 1s - 6ms/step - loss: 0.3738 - mae: 0.4696 - mse: 0.3738 - val_loss: 0.5671 - val_mae: 0.5767 - val_mse: 0.5671 Epoch 12/100 128/128 - 1s - 6ms/step - loss: 0.3616 - mae: 0.4648 - mse: 0.3616 - val_loss: 0.6396 - val_mae: 0.5961 - val_mse: 0.6396 Epoch 13/100 128/128 - 1s - 6ms/step - loss: 0.3468 - mae: 0.4535 - mse: 0.3468 - val_loss: 0.5981 - val_mae: 0.5893 - val_mse: 0.5981 Epoch 14/100 128/128 - 1s - 6ms/step - loss: 0.3442 - mae: 0.4525 - mse: 0.3442 - val_loss: 0.5377 - val_mae: 0.5487 - val_mse: 0.5377 Epoch 15/100 128/128 - 1s - 6ms/step - loss: 0.3357 - mae: 0.4475 - mse: 0.3357 - val_loss: 0.5743 - val_mae: 0.5749 - val_mse: 0.5743 Epoch 16/100 128/128 - 1s - 6ms/step - loss: 0.3274 - mae: 0.4446 - mse: 0.3274 - val_loss: 0.5678 - val_mae: 0.5627 - val_mse: 0.5678 Epoch 17/100 128/128 - 1s - 6ms/step - loss: 0.3202 - mae: 0.4423 - mse: 0.3202 - val_loss: 0.5685 - val_mae: 0.5740 - val_mse: 0.5685 Epoch 18/100 128/128 - 1s - 6ms/step - loss: 0.3114 - mae: 0.4281 - mse: 0.3114 - val_loss: 0.5313 - val_mae: 0.5509 - val_mse: 0.5313 Epoch 19/100 128/128 - 1s - 6ms/step - loss: 0.3112 - mae: 0.4309 - mse: 0.3112 - val_loss: 0.5182 - val_mae: 0.5385 - val_mse: 0.5182 Epoch 20/100 128/128 - 1s - 6ms/step - loss: 0.3028 - mae: 0.4286 - mse: 0.3028 - val_loss: 0.5302 - val_mae: 0.5464 - val_mse: 0.5302 Epoch 21/100 128/128 - 1s - 6ms/step - loss: 0.3018 - mae: 0.4259 - mse: 0.3018 - val_loss: 0.6004 - val_mae: 0.5717 - val_mse: 0.6004 Epoch 22/100 128/128 - 1s - 6ms/step - loss: 0.2938 - mae: 0.4208 - mse: 0.2938 - val_loss: 0.5669 - val_mae: 0.5825 - val_mse: 0.5669 Epoch 23/100 128/128 - 1s - 6ms/step - loss: 0.2913 - mae: 0.4169 - mse: 0.2913 - val_loss: 0.5728 - val_mae: 0.5746 - val_mse: 0.5728 Epoch 24/100 128/128 - 1s - 6ms/step - loss: 0.2834 - mae: 0.4112 - mse: 0.2834 - val_loss: 0.5568 - val_mae: 0.5615 - val_mse: 0.5568 Epoch 25/100 128/128 - 1s - 6ms/step - loss: 0.2808 - mae: 0.4102 - mse: 0.2808 - val_loss: 0.5383 - val_mae: 0.5429 - val_mse: 0.5383 Epoch 26/100 128/128 - 1s - 6ms/step - loss: 0.2812 - mae: 0.4088 - mse: 0.2812 - val_loss: 0.5249 - val_mae: 0.5527 - val_mse: 0.5249 Epoch 27/100 128/128 - 1s - 6ms/step - loss: 0.2751 - 
mae: 0.4067 - mse: 0.2751 - val_loss: 0.5515 - val_mae: 0.5380 - val_mse: 0.5515 Epoch 28/100 128/128 - 1s - 6ms/step - loss: 0.2771 - mae: 0.4060 - mse: 0.2771 - val_loss: 0.5636 - val_mae: 0.5703 - val_mse: 0.5636 Epoch 29/100 128/128 - 1s - 6ms/step - loss: 0.2698 - mae: 0.4040 - mse: 0.2698 - val_loss: 0.5937 - val_mae: 0.5774 - val_mse: 0.5937 Epoch 30/100 128/128 - 1s - 6ms/step - loss: 0.2684 - mae: 0.4017 - mse: 0.2684 - val_loss: 0.5045 - val_mae: 0.5336 - val_mse: 0.5045 Epoch 31/100 128/128 - 1s - 6ms/step - loss: 0.2650 - mae: 0.3961 - mse: 0.2650 - val_loss: 0.5497 - val_mae: 0.5611 - val_mse: 0.5497 Epoch 32/100 128/128 - 1s - 6ms/step - loss: 0.2551 - mae: 0.3946 - mse: 0.2551 - val_loss: 0.5473 - val_mae: 0.5580 - val_mse: 0.5473 Epoch 33/100 128/128 - 1s - 6ms/step - loss: 0.2564 - mae: 0.3882 - mse: 0.2564 - val_loss: 0.5207 - val_mae: 0.5388 - val_mse: 0.5207 Epoch 34/100 128/128 - 1s - 6ms/step - loss: 0.2558 - mae: 0.3916 - mse: 0.2558 - val_loss: 0.5455 - val_mae: 0.5405 - val_mse: 0.5455 Epoch 35/100 128/128 - 1s - 6ms/step - loss: 0.2489 - mae: 0.3853 - mse: 0.2489 - val_loss: 0.5034 - val_mae: 0.5365 - val_mse: 0.5034 Epoch 36/100 128/128 - 1s - 6ms/step - loss: 0.2482 - mae: 0.3844 - mse: 0.2482 - val_loss: 0.4982 - val_mae: 0.5308 - val_mse: 0.4982 Epoch 37/100 128/128 - 1s - 6ms/step - loss: 0.2479 - mae: 0.3808 - mse: 0.2479 - val_loss: 0.5971 - val_mae: 0.5776 - val_mse: 0.5971 Epoch 38/100 128/128 - 1s - 6ms/step - loss: 0.2372 - mae: 0.3708 - mse: 0.2372 - val_loss: 0.6684 - val_mae: 0.6092 - val_mse: 0.6684 Epoch 39/100 128/128 - 1s - 6ms/step - loss: 0.2407 - mae: 0.3730 - mse: 0.2407 - val_loss: 0.6313 - val_mae: 0.6212 - val_mse: 0.6313 Epoch 40/100 128/128 - 1s - 6ms/step - loss: 0.2311 - mae: 0.3633 - mse: 0.2311 - val_loss: 0.5844 - val_mae: 0.5872 - val_mse: 0.5844 Epoch 41/100 128/128 - 1s - 6ms/step - loss: 0.2287 - mae: 0.3691 - mse: 0.2287 - val_loss: 0.5722 - val_mae: 0.5422 - val_mse: 0.5722 Epoch 42/100 128/128 - 1s - 6ms/step - loss: 0.2291 - mae: 0.3610 - mse: 0.2291 - val_loss: 0.5546 - val_mae: 0.5565 - val_mse: 0.5546 Epoch 43/100 128/128 - 1s - 6ms/step - loss: 0.2332 - mae: 0.3738 - mse: 0.2332 - val_loss: 0.5095 - val_mae: 0.5301 - val_mse: 0.5095 Epoch 44/100 128/128 - 1s - 6ms/step - loss: 0.2280 - mae: 0.3655 - mse: 0.2280 - val_loss: 0.5166 - val_mae: 0.5251 - val_mse: 0.5166 Epoch 45/100 128/128 - 1s - 6ms/step - loss: 0.2265 - mae: 0.3622 - mse: 0.2265 - val_loss: 0.5281 - val_mae: 0.5556 - val_mse: 0.5281 Epoch 46/100 128/128 - 1s - 6ms/step - loss: 0.2259 - mae: 0.3616 - mse: 0.2259 - val_loss: 0.6534 - val_mae: 0.5948 - val_mse: 0.6534 Epoch 47/100 128/128 - 1s - 6ms/step - loss: 0.2233 - mae: 0.3641 - mse: 0.2233 - val_loss: 0.5343 - val_mae: 0.5293 - val_mse: 0.5343 Epoch 48/100 128/128 - 1s - 6ms/step - loss: 0.2219 - mae: 0.3596 - mse: 0.2219 - val_loss: 0.5366 - val_mae: 0.5565 - val_mse: 0.5366 Epoch 49/100 128/128 - 1s - 6ms/step - loss: 0.2251 - mae: 0.3613 - mse: 0.2251 - val_loss: 0.5248 - val_mae: 0.5365 - val_mse: 0.5248 Epoch 50/100 128/128 - 1s - 6ms/step - loss: 0.2169 - mae: 0.3551 - mse: 0.2169 - val_loss: 0.5192 - val_mae: 0.5376 - val_mse: 0.5192 Epoch 51/100 128/128 - 1s - 6ms/step - loss: 0.2100 - mae: 0.3513 - mse: 0.2100 - val_loss: 0.5145 - val_mae: 0.5444 - val_mse: 0.5145 Epoch 52/100 128/128 - 1s - 6ms/step - loss: 0.2096 - mae: 0.3514 - mse: 0.2096 - val_loss: 0.5129 - val_mae: 0.5399 - val_mse: 0.5129 Epoch 53/100 128/128 - 1s - 6ms/step - loss: 0.2136 - mae: 0.3530 - mse: 0.2136 - val_loss: 0.5305 
- val_mae: 0.5332 - val_mse: 0.5305 Epoch 54/100 128/128 - 1s - 6ms/step - loss: 0.2124 - mae: 0.3495 - mse: 0.2124 - val_loss: 0.5945 - val_mae: 0.5623 - val_mse: 0.5945 Epoch 55/100 128/128 - 1s - 6ms/step - loss: 0.2081 - mae: 0.3466 - mse: 0.2081 - val_loss: 0.5589 - val_mae: 0.5691 - val_mse: 0.5589 Epoch 56/100 128/128 - 1s - 6ms/step - loss: 0.2052 - mae: 0.3442 - mse: 0.2052 - val_loss: 0.5383 - val_mae: 0.5618 - val_mse: 0.5383 Epoch 57/100 128/128 - 1s - 6ms/step - loss: 0.1996 - mae: 0.3397 - mse: 0.1996 - val_loss: 0.5082 - val_mae: 0.5250 - val_mse: 0.5082 Epoch 58/100 128/128 - 1s - 6ms/step - loss: 0.2058 - mae: 0.3442 - mse: 0.2058 - val_loss: 0.5457 - val_mae: 0.5396 - val_mse: 0.5457 Epoch 59/100 128/128 - 1s - 6ms/step - loss: 0.1963 - mae: 0.3413 - mse: 0.1963 - val_loss: 0.5406 - val_mae: 0.5398 - val_mse: 0.5406 Epoch 60/100 128/128 - 1s - 6ms/step - loss: 0.1983 - mae: 0.3387 - mse: 0.1983 - val_loss: 0.4792 - val_mae: 0.5203 - val_mse: 0.4792 Epoch 61/100 128/128 - 1s - 6ms/step - loss: 0.1948 - mae: 0.3373 - mse: 0.1948 - val_loss: 0.5468 - val_mae: 0.5519 - val_mse: 0.5468 Epoch 62/100 128/128 - 1s - 6ms/step - loss: 0.1905 - mae: 0.3300 - mse: 0.1905 - val_loss: 0.5687 - val_mae: 0.5618 - val_mse: 0.5687 Epoch 63/100 128/128 - 1s - 6ms/step - loss: 0.2000 - mae: 0.3346 - mse: 0.2000 - val_loss: 0.5241 - val_mae: 0.5299 - val_mse: 0.5241 Epoch 64/100 128/128 - 1s - 6ms/step - loss: 0.1804 - mae: 0.3184 - mse: 0.1804 - val_loss: 0.5455 - val_mae: 0.5548 - val_mse: 0.5455 Epoch 65/100 128/128 - 1s - 6ms/step - loss: 0.1870 - mae: 0.3320 - mse: 0.1870 - val_loss: 0.5675 - val_mae: 0.5403 - val_mse: 0.5675 Epoch 66/100 128/128 - 1s - 6ms/step - loss: 0.1887 - mae: 0.3264 - mse: 0.1887 - val_loss: 0.5435 - val_mae: 0.5537 - val_mse: 0.5435 Epoch 67/100 128/128 - 1s - 6ms/step - loss: 0.1792 - mae: 0.3229 - mse: 0.1792 - val_loss: 0.6054 - val_mae: 0.5557 - val_mse: 0.6054 Epoch 68/100 128/128 - 1s - 6ms/step - loss: 0.1815 - mae: 0.3227 - mse: 0.1815 - val_loss: 0.5558 - val_mae: 0.5491 - val_mse: 0.5558 Epoch 69/100 128/128 - 1s - 6ms/step - loss: 0.1839 - mae: 0.3251 - mse: 0.1839 - val_loss: 0.5782 - val_mae: 0.5857 - val_mse: 0.5782 Epoch 70/100 128/128 - 1s - 6ms/step - loss: 0.1759 - mae: 0.3235 - mse: 0.1759 - val_loss: 0.5223 - val_mae: 0.5263 - val_mse: 0.5223 Epoch 71/100 128/128 - 1s - 6ms/step - loss: 0.1823 - mae: 0.3233 - mse: 0.1823 - val_loss: 0.5650 - val_mae: 0.5650 - val_mse: 0.5650 Epoch 72/100 128/128 - 1s - 6ms/step - loss: 0.1763 - mae: 0.3165 - mse: 0.1763 - val_loss: 0.5424 - val_mae: 0.5460 - val_mse: 0.5424 Epoch 73/100 128/128 - 1s - 6ms/step - loss: 0.1752 - mae: 0.3214 - mse: 0.1752 - val_loss: 0.6822 - val_mae: 0.6173 - val_mse: 0.6822 Epoch 74/100 128/128 - 1s - 6ms/step - loss: 0.1783 - mae: 0.3209 - mse: 0.1783 - val_loss: 0.5366 - val_mae: 0.5532 - val_mse: 0.5366 Epoch 75/100 128/128 - 1s - 6ms/step - loss: 0.1666 - mae: 0.3104 - mse: 0.1666 - val_loss: 0.5550 - val_mae: 0.5645 - val_mse: 0.5550 Epoch 76/100 128/128 - 1s - 6ms/step - loss: 0.1736 - mae: 0.3122 - mse: 0.1736 - val_loss: 0.5431 - val_mae: 0.5398 - val_mse: 0.5431 Epoch 77/100 128/128 - 1s - 6ms/step - loss: 0.1684 - mae: 0.3154 - mse: 0.1684 - val_loss: 0.5474 - val_mae: 0.5425 - val_mse: 0.5474 Epoch 78/100 128/128 - 1s - 6ms/step - loss: 0.1677 - mae: 0.3130 - mse: 0.1677 - val_loss: 0.5712 - val_mae: 0.5477 - val_mse: 0.5712 Epoch 79/100 128/128 - 1s - 6ms/step - loss: 0.1701 - mae: 0.3114 - mse: 0.1701 - val_loss: 0.5568 - val_mae: 0.5644 - val_mse: 0.5568 Epoch 
80/100 128/128 - 1s - 6ms/step - loss: 0.1659 - mae: 0.3100 - mse: 0.1659 - val_loss: 0.5289 - val_mae: 0.5464 - val_mse: 0.5289 Epoch 81/100 128/128 - 1s - 6ms/step - loss: 0.1674 - mae: 0.3096 - mse: 0.1674 - val_loss: 0.5783 - val_mae: 0.5531 - val_mse: 0.5783 Epoch 82/100 128/128 - 1s - 6ms/step - loss: 0.1624 - mae: 0.3069 - mse: 0.1624 - val_loss: 0.6191 - val_mae: 0.5803 - val_mse: 0.6191 Epoch 83/100 128/128 - 1s - 6ms/step - loss: 0.1626 - mae: 0.3048 - mse: 0.1626 - val_loss: 0.5592 - val_mae: 0.5413 - val_mse: 0.5592 Epoch 84/100 128/128 - 1s - 6ms/step - loss: 0.1579 - mae: 0.3036 - mse: 0.1579 - val_loss: 0.5685 - val_mae: 0.5431 - val_mse: 0.5685 Epoch 85/100 128/128 - 1s - 6ms/step - loss: 0.1537 - mae: 0.3012 - mse: 0.1537 - val_loss: 0.5914 - val_mae: 0.5786 - val_mse: 0.5914 Epoch 86/100 128/128 - 1s - 6ms/step - loss: 0.1565 - mae: 0.2988 - mse: 0.1565 - val_loss: 0.5660 - val_mae: 0.5604 - val_mse: 0.5660 Epoch 87/100 128/128 - 1s - 6ms/step - loss: 0.1565 - mae: 0.2964 - mse: 0.1565 - val_loss: 0.5986 - val_mae: 0.5932 - val_mse: 0.5986 Epoch 88/100 128/128 - 1s - 6ms/step - loss: 0.1558 - mae: 0.3010 - mse: 0.1558 - val_loss: 0.5877 - val_mae: 0.5616 - val_mse: 0.5877 Epoch 89/100 128/128 - 1s - 6ms/step - loss: 0.1564 - mae: 0.2997 - mse: 0.1564 - val_loss: 0.5504 - val_mae: 0.5542 - val_mse: 0.5504 Epoch 90/100 128/128 - 1s - 6ms/step - loss: 0.1509 - mae: 0.2912 - mse: 0.1509 - val_loss: 0.5753 - val_mae: 0.5832 - val_mse: 0.5753 Epoch 91/100 128/128 - 1s - 6ms/step - loss: 0.1526 - mae: 0.2983 - mse: 0.1526 - val_loss: 0.6000 - val_mae: 0.5703 - val_mse: 0.6000 Epoch 92/100 128/128 - 1s - 6ms/step - loss: 0.1492 - mae: 0.2915 - mse: 0.1492 - val_loss: 0.5682 - val_mae: 0.5559 - val_mse: 0.5682 Epoch 93/100 128/128 - 1s - 6ms/step - loss: 0.1491 - mae: 0.2902 - mse: 0.1491 - val_loss: 0.5545 - val_mae: 0.5434 - val_mse: 0.5545 Epoch 94/100 128/128 - 1s - 6ms/step - loss: 0.1470 - mae: 0.2924 - mse: 0.1470 - val_loss: 0.5582 - val_mae: 0.5537 - val_mse: 0.5582 Epoch 95/100 128/128 - 1s - 6ms/step - loss: 0.1434 - mae: 0.2878 - mse: 0.1434 - val_loss: 0.5798 - val_mae: 0.5478 - val_mse: 0.5798 Epoch 96/100 128/128 - 1s - 6ms/step - loss: 0.1473 - mae: 0.2937 - mse: 0.1473 - val_loss: 0.5653 - val_mae: 0.5522 - val_mse: 0.5653 Epoch 97/100 128/128 - 1s - 6ms/step - loss: 0.1462 - mae: 0.2889 - mse: 0.1462 - val_loss: 0.6713 - val_mae: 0.5923 - val_mse: 0.6713 Epoch 98/100 128/128 - 1s - 6ms/step - loss: 0.1448 - mae: 0.2888 - mse: 0.1448 - val_loss: 0.6224 - val_mae: 0.6117 - val_mse: 0.6224 Epoch 99/100 128/128 - 1s - 6ms/step - loss: 0.1417 - mae: 0.2869 - mse: 0.1417 - val_loss: 0.6380 - val_mae: 0.5804 - val_mse: 0.6380 Epoch 100/100 128/128 - 1s - 6ms/step - loss: 0.1434 - mae: 0.2901 - mse: 0.1434 - val_loss: 0.5715 - val_mae: 0.5463 - val_mse: 0.5715
score = model.evaluate(x_test, y_test, verbose=0)
print('x_test / loss : {:5.4f}'.format(score[0]))
print('x_test / mae : {:5.4f}'.format(score[1]))
print('x_test / mse : {:5.4f}'.format(score[2]))
x_test / loss : 0.5715
x_test / mae : 0.5463
x_test / mse : 0.5715
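Since the quality score lies on a 0-10 scale, the RMSE (square root of the MSE) gives the typical prediction error in quality points; a quick sketch:

# ---- Typical error in quality points : RMSE = sqrt(MSE)
print('x_test / rmse : {:5.4f}'.format(np.sqrt(score[2])))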
6.2 - Training history
What was the best result during our training?
print("min( val_mae ) : {:.4f}".format( min(history.history["val_mae"]) ) )
min( val_mae ) : 0.5203
fidle.scrawler.history( history, plot={'MSE' :['mse', 'val_mse'],
'MAE' :['mae', 'val_mae'],
'LOSS':['loss','val_loss']}, save_as='01-history')
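If the fidle helper is not available, the same curves can be drawn directly from history.history with matplotlib; a minimal sketch for the MAE curves:

import matplotlib.pyplot as plt

# ---- Plot train / validation MAE per epoch
plt.plot(history.history['mae'],     label='mae')
plt.plot(history.history['val_mae'], label='val_mae')
plt.xlabel('Epoch')
plt.ylabel('MAE')
plt.legend()
plt.show()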
Step 7 - Restore a model
7.1 - Reload model
loaded_model = keras.models.load_model('./run/models/best_model.keras')
loaded_model.summary()
print("Loaded.")
Model: "sequential"
| Layer (type) | Output Shape | Param # |
|---|---|---|
| Dense_n1 (Dense) | (None, 64) | 768 |
| Dense_n2 (Dense) | (None, 64) | 4,160 |
| Output (Dense) | (None, 1) | 65 |
Total params: 9,988 (39.02 KB)
Trainable params: 4,993 (19.50 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 4,995 (19.51 KB)
Loaded.
7.2 - Evaluate it
score = loaded_model.evaluate(x_test, y_test, verbose=0)
print('x_test / loss : {:5.4f}'.format(score[0]))
print('x_test / mae : {:5.4f}'.format(score[1]))
print('x_test / mse : {:5.4f}'.format(score[2]))
x_test / loss : 2.1294
x_test / mae : 1.1600
x_test / mse : 2.1294
7.3 - Make a prediction
# ---- Pick n entries from our test set
n = 200
ii = np.random.randint(0, len(x_test), n)   # n random indices over the whole test set
x_sample = x_test[ii]
y_sample = y_test[ii]
# ---- Make predictions
y_pred = loaded_model.predict( x_sample, verbose=2 )
7/7 - 0s - 3ms/step
# ---- Show it
print('Wine Prediction Real Delta')
for i in range(n):
pred = y_pred[i][0]
real = y_sample[i]
delta = real-pred
print(f'{i:03d} {pred:.2f} {real} {delta:+.2f} ')
Wine Prediction Real Delta 000 4.52 5 +0.48 001 5.29 6 +0.71 002 4.98 6 +1.02 003 5.58 6 +0.42 004 5.93 4 -1.93 005 6.12 6 -0.12 006 5.59 5 -0.59 007 10.79 6 -4.79 008 6.52 6 -0.52 009 7.00 6 -1.00 010 6.44 5 -1.44 011 5.09 6 +0.91 012 5.02 6 +0.98 013 4.53 4 -0.53 014 4.13 5 +0.87 015 4.89 6 +1.11 016 6.90 6 -0.90 017 6.60 5 -1.60 018 6.34 6 -0.34 019 5.84 6 +0.16 020 4.52 6 +1.48 021 4.89 6 +1.11 022 5.81 4 -1.81 023 4.40 6 +1.60 024 6.46 6 -0.46 025 4.94 5 +0.06 026 5.68 5 -0.68 027 5.59 5 -0.59 028 5.60 4 -1.60 029 4.35 6 +1.65 030 5.65 5 -0.65 031 5.02 6 +0.98 032 4.69 6 +1.31 033 5.13 5 -0.13 034 4.70 6 +1.30 035 6.31 7 +0.69 036 5.48 5 -0.48 037 4.91 6 +1.09 038 5.02 6 +0.98 039 4.65 5 +0.35 040 4.91 7 +2.09 041 5.01 6 +0.99 042 5.73 6 +0.27 043 6.56 5 -1.56 044 6.08 7 +0.92 045 5.15 6 +0.85 046 3.53 6 +2.47 047 6.90 6 -0.90 048 5.39 6 +0.61 049 4.86 5 +0.14 050 5.12 7 +1.88 051 3.99 6 +2.01 052 5.26 6 +0.74 053 5.51 6 +0.49 054 5.69 5 -0.69 055 5.03 6 +0.97 056 7.29 6 -1.29 057 5.11 5 -0.11 058 4.01 6 +1.99 059 6.08 5 -1.08 060 6.22 5 -1.22 061 6.77 5 -1.77 062 4.13 5 +0.87 063 5.26 7 +1.74 064 4.35 5 +0.65 065 7.54 8 +0.46 066 5.37 6 +0.63 067 5.46 5 -0.46 068 4.02 6 +1.98 069 6.56 5 -1.56 070 5.76 5 -0.76 071 6.01 6 -0.01 072 3.78 5 +1.22 073 4.09 6 +1.91 074 6.19 5 -1.19 075 5.02 6 +0.98 076 5.09 6 +0.91 077 4.47 7 +2.53 078 4.68 5 +0.32 079 4.96 6 +1.04 080 5.01 5 -0.01 081 5.26 7 +1.74 082 4.98 6 +1.02 083 6.62 5 -1.62 084 5.41 5 -0.41 085 5.26 6 +0.74 086 6.01 7 +0.99 087 4.86 6 +1.14 088 6.77 5 -1.77 089 9.24 5 -4.24 090 6.84 5 -1.84 091 4.86 5 +0.14 092 5.75 5 -0.75 093 5.97 6 +0.03 094 6.60 5 -1.60 095 4.94 7 +2.06 096 3.68 5 +1.32 097 5.81 6 +0.19 098 4.52 5 +0.48 099 6.48 6 -0.48 100 4.06 6 +1.94 101 5.26 6 +0.74 102 6.30 7 +0.70 103 3.11 5 +1.89 104 5.21 5 -0.21 105 5.09 6 +0.91 106 6.22 5 -1.22 107 4.57 6 +1.43 108 3.82 6 +2.18 109 3.99 5 +1.01 110 5.48 5 -0.48 111 7.04 6 -1.04 112 5.02 6 +0.98 113 5.29 6 +0.71 114 6.70 5 -1.70 115 4.40 6 +1.60 116 5.11 5 -0.11 117 5.78 5 -0.78 118 5.01 5 -0.01 119 7.54 8 +0.46 120 5.26 6 +0.74 121 7.00 6 -1.00 122 5.23 6 +0.77 123 5.03 6 +0.97 124 5.59 5 -0.59 125 5.58 6 +0.42 126 6.90 6 -0.90 127 5.07 5 -0.07 128 5.79 6 +0.21 129 5.19 6 +0.81 130 6.61 4 -2.61 131 5.88 5 -0.88 132 5.37 6 +0.63 133 4.25 6 +1.75 134 4.79 6 +1.21 135 5.15 6 +0.85 136 3.80 6 +2.20 137 3.99 5 +1.01 138 5.01 6 +0.99 139 4.48 5 +0.52 140 6.01 7 +0.99 141 4.44 5 +0.56 142 6.07 5 -1.07 143 5.71 6 +0.29 144 6.64 5 -1.64 145 6.19 5 -1.19 146 6.46 6 -0.46 147 3.78 5 +1.22 148 6.77 6 -0.77 149 5.48 5 -0.48 150 3.87 6 +2.13 151 4.05 7 +2.95 152 5.84 6 +0.16 153 4.86 5 +0.14 154 5.58 6 +0.42 155 5.76 5 -0.76 156 7.30 6 -1.30 157 3.52 6 +2.48 158 6.83 7 +0.17 159 4.94 5 +0.06 160 5.89 5 -0.89 161 4.65 5 +0.35 162 5.51 6 +0.49 163 3.58 6 +2.42 164 5.13 5 -0.13 165 5.72 7 +1.28 166 7.30 6 -1.30 167 4.53 5 +0.47 168 5.02 6 +0.98 169 4.57 7 +2.43 170 3.98 5 +1.02 171 4.26 5 +0.74 172 5.50 6 +0.50 173 4.30 5 +0.70 174 5.58 6 +0.42 175 5.21 5 -0.21 176 4.14 6 +1.86 177 5.11 6 +0.89 178 5.85 6 +0.15 179 4.09 5 +0.91 180 5.93 4 -1.93 181 6.08 5 -1.08 182 5.35 6 +0.65 183 6.01 6 -0.01 184 5.20 8 +2.80 185 4.05 7 +2.95 186 5.81 4 -1.81 187 5.96 7 +1.04 188 9.90 5 -4.90 189 6.44 8 +1.56 190 6.08 7 +0.92 191 7.29 6 -1.29 192 4.68 5 +0.32 193 4.52 5 +0.48 194 5.70 5 -0.70 195 4.73 5 +0.27 196 5.74 7 +1.26 197 5.89 5 -0.89 198 4.75 6 +1.25 199 6.22 5 -1.22
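Rather than reading these 200 lines one by one, the sampled predictions can be summarized; a minimal sketch computing the sample MAE and the share of exactly predicted scores once rounded to the nearest integer:

# ---- Optional : summarize the sampled predictions
errors = y_sample - y_pred[:,0]
print('Sample MAE            : {:.2f}'.format( np.abs(errors).mean() ))
print('Exact score (rounded) : {:.1f} %'.format( 100*np.mean( np.round(y_pred[:,0]) == y_sample ) ))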
A few questions :
- Can this model be used for red wines from Bordeaux and/or Beaujolais?
- What are the limitations of this model?
- What are the limitations of this dataset?
fidle.end()
End time : 03/03/24 21:05:54
Duration : 00:01:20 827ms
This notebook ends here :-)
https://fidle.cnrs.fr