[K3WINE1] - Wine quality prediction with a Dense Network (DNN)¶
Another example of regression, this time predicting wine quality, using Keras 3 and PyTorch.
Objectives :¶
- Predict the quality of wines, based on their analysis
- Understand the principle and architecture of a regression with a dense neural network, with saving and restoring of the trained model
The Wine Quality dataset consists of analyses of a large number of wines, each with an associated quality score (between 0 and 10).
This dataset is provided by :
Paulo Cortez, University of Minho, Guimarães, Portugal, http://www3.dsi.uminho.pt/pcortez
A. Cerdeira, F. Almeida, T. Matos and J. Reis, Viticulture Commission of the Vinho Verde Region (CVRVV), Porto, Portugal, @2009
This dataset can be retrieved from the University of California Irvine (UCI) Machine Learning Repository.
Due to privacy and logistics issues, only physicochemical and sensory variables are available.
There is no data about grape types, wine brand, wine selling price, etc. The available attributes are :
- fixed acidity
- volatile acidity
- citric acid
- residual sugar
- chlorides
- free sulfur dioxide
- total sulfur dioxide
- density
- pH
- sulphates
- alcohol
- quality (score between 0 and 10)
What we're going to do :¶
- (Retrieve data)
- (Preparing the data)
- (Build a model)
- Train and save the model
- Restore saved model
- Evaluate the model
- Make some predictions
Step 1 - Import and init¶
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
import numpy as np
import pandas as pd
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3WINE1')
FIDLE - Environment initialization
Version              : 2.3.2
Run id               : K3WINE1
Run dir              : ./run/K3WINE1
Datasets dir         : /lustre/fswork/projects/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 22/12/24 21:21:04
Hostname             : r3i7n1 (Linux)
Tensorflow log level : Info + Warning + Error (=0)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3WINE1/figs (True)
keras                : 3.7.0
numpy                : 2.1.2
sklearn              : 1.5.2
yaml                 : 6.0.2
matplotlib           : 3.9.2
pandas               : 2.2.3
torch                : 2.5.0
Verbosity during training :
- 0 = silent
- 1 = progress bar
- 2 = one line per epoch
fit_verbosity = 1
dataset_name = 'winequality-red.csv'
Override parameters (batch mode) - Just forget this cell
fidle.override('fit_verbosity', 'dataset_name')
** Overridden parameters : **
fit_verbosity : 2
Step 2 - Retrieve data¶
data = pd.read_csv(f'{datasets_dir}/WineQuality/origine/{dataset_name}', header=0, sep=';')
display(data.head(5).style.format("{0:.2f}"))
print('Missing Data : ',data.isna().sum().sum(), ' Shape is : ', data.shape)
|   | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7.40 | 0.70 | 0.00 | 1.90 | 0.08 | 11.00 | 34.00 | 1.00 | 3.51 | 0.56 | 9.40 | 5.00 |
| 1 | 7.80 | 0.88 | 0.00 | 2.60 | 0.10 | 25.00 | 67.00 | 1.00 | 3.20 | 0.68 | 9.80 | 5.00 |
| 2 | 7.80 | 0.76 | 0.04 | 2.30 | 0.09 | 15.00 | 54.00 | 1.00 | 3.26 | 0.65 | 9.80 | 5.00 |
| 3 | 11.20 | 0.28 | 0.56 | 1.90 | 0.07 | 17.00 | 60.00 | 1.00 | 3.16 | 0.58 | 9.80 | 6.00 |
| 4 | 7.40 | 0.70 | 0.00 | 1.90 | 0.08 | 11.00 | 34.00 | 1.00 | 3.51 | 0.56 | 9.40 | 5.00 |
Missing Data : 0 Shape is : (1599, 12)
Step 3 - Preparing the data¶
3.1 - Split data¶
# ---- Split => train, test
#
data = data.sample(frac=1., axis=0) # Shuffle
data_train = data.sample(frac=0.8, axis=0) # get 80 %
data_test = data.drop(data_train.index) # test = all - train
# ---- Split => x,y (quality is the target)
#
x_train = data_train.drop('quality', axis=1)
y_train = data_train['quality']
x_test = data_test.drop('quality', axis=1)
y_test = data_test['quality']
print('Original data shape was : ',data.shape)
print('x_train : ',x_train.shape, 'y_train : ',y_train.shape)
print('x_test : ',x_test.shape, 'y_test : ',y_test.shape)
Original data shape was : (1599, 12)
x_train : (1279, 11) y_train : (1279,)
x_test : (320, 11) y_test : (320,)
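Note : an equivalent split can be obtained with scikit-learn — a minimal sketch, not used in the rest of this notebook (the random_state value is arbitrary) :
from sklearn.model_selection import train_test_split

# ---- Separate features/target, then split 80/20 with a fixed seed for reproducibility
x = data.drop('quality', axis=1)
y = data['quality']
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)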
3.2 - Data normalization¶
Note :
- All input data must be normalized, train and test.
- To do this we will subtract the mean and divide by the standard deviation.
- But test data should not be used in any way, even for normalization.
- The mean and the standard deviation will therefore only be calculated with the train data.
display(x_train.describe().style.format("{0:.2f}").set_caption("Before normalization :"))
mean = x_train.mean()
std = x_train.std()
x_train = (x_train - mean) / std
x_test = (x_test - mean) / std
display(x_train.describe().style.format("{0:.2f}").set_caption("After normalization :"))
# ---- Convert our DataFrames to numpy arrays
x_train, y_train = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)
Before normalization :

|   | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol |
|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 |
| mean | 8.32 | 0.53 | 0.27 | 2.53 | 0.09 | 15.99 | 46.53 | 1.00 | 3.31 | 0.66 | 10.41 |
| std | 1.75 | 0.18 | 0.20 | 1.40 | 0.04 | 10.37 | 32.83 | 0.00 | 0.15 | 0.17 | 1.06 |
| min | 4.60 | 0.12 | 0.00 | 0.90 | 0.01 | 1.00 | 6.00 | 0.99 | 2.74 | 0.33 | 8.40 |
| 25% | 7.10 | 0.40 | 0.09 | 1.90 | 0.07 | 8.00 | 22.50 | 1.00 | 3.21 | 0.55 | 9.50 |
| 50% | 7.90 | 0.52 | 0.25 | 2.20 | 0.08 | 14.00 | 38.00 | 1.00 | 3.31 | 0.62 | 10.10 |
| 75% | 9.30 | 0.64 | 0.43 | 2.60 | 0.09 | 21.00 | 61.50 | 1.00 | 3.40 | 0.73 | 11.08 |
| max | 15.60 | 1.33 | 1.00 | 15.50 | 0.61 | 68.00 | 289.00 | 1.00 | 4.01 | 2.00 | 14.00 |
After normalization :

|   | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol |
|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 | 1279.00 |
| mean | 0.00 | -0.00 | -0.00 | 0.00 | -0.00 | 0.00 | -0.00 | -0.00 | -0.00 | 0.00 | 0.00 |
| std | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| min | -2.13 | -2.32 | -1.37 | -1.17 | -1.70 | -1.45 | -1.23 | -3.50 | -3.72 | -1.96 | -1.89 |
| 25% | -0.70 | -0.73 | -0.91 | -0.45 | -0.38 | -0.77 | -0.73 | -0.61 | -0.66 | -0.64 | -0.85 |
| 50% | -0.24 | -0.05 | -0.09 | -0.24 | -0.18 | -0.19 | -0.26 | 0.02 | -0.00 | -0.23 | -0.29 |
| 75% | 0.56 | 0.62 | 0.83 | 0.05 | 0.07 | 0.48 | 0.46 | 0.56 | 0.58 | 0.43 | 0.64 |
| max | 4.16 | 4.53 | 3.74 | 9.27 | 11.89 | 5.02 | 7.39 | 3.63 | 4.56 | 8.01 | 3.38 |
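Note : the same normalization can be done with scikit-learn's StandardScaler — a minimal sketch, fitted on the train set only. (StandardScaler uses the population standard deviation, ddof=0, whereas pandas .std() uses ddof=1, so results differ very slightly.)
from sklearn.preprocessing import StandardScaler

scaler  = StandardScaler()
x_train = scaler.fit_transform(x_train)   # fit on train data only...
x_test  = scaler.transform(x_test)        # ...and reuse the train statistics for test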
Step 4 - Build a model¶
def get_model_v1(shape):
model = keras.models.Sequential()
model.add(keras.layers.Input(shape, name="InputLayer"))
model.add(keras.layers.Dense(64, activation='relu', name='Dense_n1'))
model.add(keras.layers.Dense(64, activation='relu', name='Dense_n2'))
model.add(keras.layers.Dense(1, name='Output'))
model.compile(optimizer = 'rmsprop',
loss = 'mse',
metrics = ['mae', 'mse'] )
return model
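The same network can also be written with the Keras functional API — a minimal sketch, equivalent to get_model_v1 above :
def get_model_v1_functional(shape):
    # Same architecture as get_model_v1, expressed as a graph of layer calls
    inputs  = keras.layers.Input(shape, name='InputLayer')
    x       = keras.layers.Dense(64, activation='relu', name='Dense_n1')(inputs)
    x       = keras.layers.Dense(64, activation='relu', name='Dense_n2')(x)
    outputs = keras.layers.Dense(1, name='Output')(x)
    model   = keras.Model(inputs, outputs)
    model.compile(optimizer='rmsprop', loss='mse', metrics=['mae', 'mse'])
    return model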
Step 5 - Train the model¶
5.1 - Get the model¶
model = get_model_v1( (11,) )
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ Dense_n1 (Dense)                     │ (None, 64)                  │             768 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Dense_n2 (Dense)                     │ (None, 64)                  │           4,160 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Output (Dense)                       │ (None, 1)                   │              65 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 4,993 (19.50 KB)
Trainable params: 4,993 (19.50 KB)
Non-trainable params: 0 (0.00 B)
5.2 - Add callback¶
os.makedirs('./run/models', mode=0o750, exist_ok=True)
save_dir = "./run/models/best_model.keras"
savemodel_callback = keras.callbacks.ModelCheckpoint( filepath=save_dir,
                                                      monitor='val_mae',
                                                      mode='min',            # MAE must be minimized; 'max' would keep the worst model
                                                      save_best_only=True)
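Other callbacks could be added alongside the checkpoint; for example, a sketch with keras.callbacks.EarlyStopping, to stop training once val_mae stops improving (the patience value is arbitrary) :
earlystop_callback = keras.callbacks.EarlyStopping( monitor='val_mae',
                                                    mode='min',
                                                    patience=10,
                                                    restore_best_weights=True )
# It would then be passed to model.fit() via callbacks=[savemodel_callback, earlystop_callback]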
5.3 - Train it¶
history = model.fit(x_train,
y_train,
epochs = 100,
batch_size = 10,
verbose = fit_verbosity,
validation_data = (x_test, y_test),
callbacks = [savemodel_callback])
Epoch 1/100
128/128 - 2s - 15ms/step - loss: 6.8173 - mae: 2.0394 - mse: 6.8173 - val_loss: 1.9679 - val_mae: 1.0980 - val_mse: 1.9679
Epoch 2/100
128/128 - 1s - 6ms/step - loss: 1.4374 - mae: 0.9350 - mse: 1.4374 - val_loss: 1.1838 - val_mae: 0.8539 - val_mse: 1.1838
Epoch 3/100
128/128 - 1s - 6ms/step - loss: 0.9545 - mae: 0.7510 - mse: 0.9545 - val_loss: 0.8727 - val_mae: 0.7226 - val_mse: 0.8727
Epoch 4/100
128/128 - 1s - 6ms/step - loss: 0.7503 - mae: 0.6711 - mse: 0.7503 - val_loss: 0.6673 - val_mae: 0.6248 - val_mse: 0.6673
Epoch 5/100
128/128 - 1s - 6ms/step - loss: 0.6203 - mae: 0.6117 - mse: 0.6203 - val_loss: 0.5975 - val_mae: 0.5857 - val_mse: 0.5975
Epoch 6/100
128/128 - 1s - 6ms/step - loss: 0.5537 - mae: 0.5811 - mse: 0.5537 - val_loss: 0.6115 - val_mae: 0.5938 - val_mse: 0.6115
Epoch 7/100
128/128 - 1s - 6ms/step - loss: 0.5130 - mae: 0.5558 - mse: 0.5130 - val_loss: 0.5651 - val_mae: 0.5351 - val_mse: 0.5651
Epoch 8/100
128/128 - 1s - 6ms/step - loss: 0.4784 - mae: 0.5335 - mse: 0.4784 - val_loss: 0.4816 - val_mae: 0.5291 - val_mse: 0.4816
Epoch 9/100
128/128 - 1s - 6ms/step - loss: 0.4559 - mae: 0.5207 - mse: 0.4559 - val_loss: 0.4662 - val_mae: 0.5036 - val_mse: 0.4662
Epoch 10/100
128/128 - 1s - 6ms/step - loss: 0.4329 - mae: 0.5104 - mse: 0.4329 - val_loss: 0.4895 - val_mae: 0.5207 - val_mse: 0.4895
Epoch 11/100
128/128 - 1s - 6ms/step - loss: 0.4247 - mae: 0.5055 - mse: 0.4247 - val_loss: 0.4828 - val_mae: 0.5251 - val_mse: 0.4828
Epoch 12/100
128/128 - 1s - 6ms/step - loss: 0.4145 - mae: 0.4992 - mse: 0.4145 - val_loss: 0.4120 - val_mae: 0.4785 - val_mse: 0.4120
Epoch 13/100
128/128 - 1s - 6ms/step - loss: 0.3995 - mae: 0.4896 - mse: 0.3995 - val_loss: 0.4190 - val_mae: 0.4578 - val_mse: 0.4190
Epoch 14/100
128/128 - 1s - 6ms/step - loss: 0.3861 - mae: 0.4785 - mse: 0.3861 - val_loss: 0.3771 - val_mae: 0.4577 - val_mse: 0.3771
Epoch 15/100
128/128 - 1s - 6ms/step - loss: 0.3738 - mae: 0.4771 - mse: 0.3738 - val_loss: 0.4346 - val_mae: 0.4904 - val_mse: 0.4346
Epoch 16/100
128/128 - 1s - 6ms/step - loss: 0.3741 - mae: 0.4783 - mse: 0.3741 - val_loss: 0.3868 - val_mae: 0.4629 - val_mse: 0.3868
Epoch 17/100
128/128 - 1s - 6ms/step - loss: 0.3741 - mae: 0.4761 - mse: 0.3741 - val_loss: 0.3960 - val_mae: 0.4461 - val_mse: 0.3960
Epoch 18/100
128/128 - 1s - 6ms/step - loss: 0.3658 - mae: 0.4689 - mse: 0.3658 - val_loss: 0.4617 - val_mae: 0.4880 - val_mse: 0.4617
Epoch 19/100
128/128 - 1s - 6ms/step - loss: 0.3604 - mae: 0.4656 - mse: 0.3604 - val_loss: 0.4149 - val_mae: 0.4654 - val_mse: 0.4149
Epoch 20/100
128/128 - 1s - 6ms/step - loss: 0.3594 - mae: 0.4640 - mse: 0.3594 - val_loss: 0.3704 - val_mae: 0.4494 - val_mse: 0.3704
Epoch 21/100
128/128 - 1s - 6ms/step - loss: 0.3488 - mae: 0.4551 - mse: 0.3488 - val_loss: 0.4119 - val_mae: 0.4660 - val_mse: 0.4119
Epoch 22/100
128/128 - 1s - 6ms/step - loss: 0.3456 - mae: 0.4592 - mse: 0.3456 - val_loss: 0.4034 - val_mae: 0.4762 - val_mse: 0.4034
Epoch 23/100
128/128 - 1s - 6ms/step - loss: 0.3465 - mae: 0.4542 - mse: 0.3465 - val_loss: 0.3710 - val_mae: 0.4507 - val_mse: 0.3710
Epoch 24/100
128/128 - 1s - 6ms/step - loss: 0.3395 - mae: 0.4544 - mse: 0.3395 - val_loss: 0.3651 - val_mae: 0.4417 - val_mse: 0.3651
Epoch 25/100
128/128 - 1s - 6ms/step - loss: 0.3274 - mae: 0.4429 - mse: 0.3274 - val_loss: 0.4090 - val_mae: 0.4787 - val_mse: 0.4090
Epoch 26/100
128/128 - 1s - 6ms/step - loss: 0.3286 - mae: 0.4429 - mse: 0.3286 - val_loss: 0.4292 - val_mae: 0.4845 - val_mse: 0.4292
Epoch 27/100
128/128 - 1s - 6ms/step - loss: 0.3240 - mae: 0.4444 - mse: 0.3240 - val_loss: 0.4364 - val_mae: 0.4856 - val_mse: 0.4364
Epoch 28/100
128/128 - 1s - 6ms/step - loss: 0.3285 - mae: 0.4426 - mse: 0.3285 - val_loss: 0.3805 - val_mae: 0.4399 - val_mse: 0.3805
Epoch 29/100
128/128 - 1s - 6ms/step - loss: 0.3186 - mae: 0.4356 - mse: 0.3186 - val_loss: 0.4233 - val_mae: 0.4670 - val_mse: 0.4233
Epoch 30/100
128/128 - 1s - 6ms/step - loss: 0.3153 - mae: 0.4366 - mse: 0.3153 - val_loss: 0.3824 - val_mae: 0.4479 - val_mse: 0.3824
Epoch 31/100
128/128 - 1s - 6ms/step - loss: 0.3074 - mae: 0.4264 - mse: 0.3074 - val_loss: 0.3639 - val_mae: 0.4584 - val_mse: 0.3639
Epoch 32/100
128/128 - 1s - 6ms/step - loss: 0.3105 - mae: 0.4290 - mse: 0.3105 - val_loss: 0.4044 - val_mae: 0.4458 - val_mse: 0.4044
Epoch 33/100
128/128 - 1s - 6ms/step - loss: 0.3018 - mae: 0.4239 - mse: 0.3018 - val_loss: 0.3586 - val_mae: 0.4287 - val_mse: 0.3586
Epoch 34/100
128/128 - 1s - 6ms/step - loss: 0.3060 - mae: 0.4270 - mse: 0.3060 - val_loss: 0.4015 - val_mae: 0.4573 - val_mse: 0.4015
Epoch 35/100
128/128 - 1s - 6ms/step - loss: 0.2957 - mae: 0.4171 - mse: 0.2957 - val_loss: 0.3620 - val_mae: 0.4343 - val_mse: 0.3620
Epoch 36/100
128/128 - 1s - 6ms/step - loss: 0.2937 - mae: 0.4190 - mse: 0.2937 - val_loss: 0.3391 - val_mae: 0.4503 - val_mse: 0.3391
Epoch 37/100
128/128 - 1s - 6ms/step - loss: 0.2857 - mae: 0.4143 - mse: 0.2857 - val_loss: 0.4127 - val_mae: 0.4833 - val_mse: 0.4127
Epoch 38/100
128/128 - 1s - 6ms/step - loss: 0.2910 - mae: 0.4161 - mse: 0.2910 - val_loss: 0.3518 - val_mae: 0.4309 - val_mse: 0.3518
Epoch 39/100
128/128 - 1s - 6ms/step - loss: 0.2856 - mae: 0.4111 - mse: 0.2856 - val_loss: 0.4506 - val_mae: 0.4904 - val_mse: 0.4506
Epoch 40/100
128/128 - 1s - 6ms/step - loss: 0.2890 - mae: 0.4137 - mse: 0.2890 - val_loss: 0.3379 - val_mae: 0.4264 - val_mse: 0.3379
Epoch 41/100
128/128 - 1s - 6ms/step - loss: 0.2873 - mae: 0.4099 - mse: 0.2873 - val_loss: 0.4576 - val_mae: 0.5278 - val_mse: 0.4576
Epoch 42/100
128/128 - 1s - 6ms/step - loss: 0.2786 - mae: 0.4053 - mse: 0.2786 - val_loss: 0.3307 - val_mae: 0.4310 - val_mse: 0.3307
Epoch 43/100
128/128 - 1s - 6ms/step - loss: 0.2779 - mae: 0.4049 - mse: 0.2779 - val_loss: 0.4032 - val_mae: 0.4793 - val_mse: 0.4032
Epoch 44/100
128/128 - 1s - 6ms/step - loss: 0.2727 - mae: 0.4040 - mse: 0.2727 - val_loss: 0.5626 - val_mae: 0.5710 - val_mse: 0.5626
Epoch 45/100
128/128 - 1s - 6ms/step - loss: 0.2721 - mae: 0.3983 - mse: 0.2721 - val_loss: 0.3801 - val_mae: 0.4469 - val_mse: 0.3801
Epoch 46/100
128/128 - 1s - 6ms/step - loss: 0.2646 - mae: 0.3983 - mse: 0.2646 - val_loss: 0.3432 - val_mae: 0.4357 - val_mse: 0.3432
Epoch 47/100
128/128 - 1s - 6ms/step - loss: 0.2597 - mae: 0.3924 - mse: 0.2597 - val_loss: 0.3822 - val_mae: 0.4718 - val_mse: 0.3822
Epoch 48/100
128/128 - 1s - 6ms/step - loss: 0.2617 - mae: 0.3945 - mse: 0.2617 - val_loss: 0.3742 - val_mae: 0.4457 - val_mse: 0.3742
Epoch 49/100
128/128 - 1s - 6ms/step - loss: 0.2642 - mae: 0.3968 - mse: 0.2642 - val_loss: 0.3297 - val_mae: 0.4337 - val_mse: 0.3297
Epoch 50/100
128/128 - 1s - 6ms/step - loss: 0.2550 - mae: 0.3902 - mse: 0.2550 - val_loss: 0.4676 - val_mae: 0.5114 - val_mse: 0.4676
Epoch 51/100
128/128 - 1s - 6ms/step - loss: 0.2556 - mae: 0.3877 - mse: 0.2556 - val_loss: 0.4017 - val_mae: 0.4779 - val_mse: 0.4017
Epoch 52/100
128/128 - 1s - 6ms/step - loss: 0.2457 - mae: 0.3758 - mse: 0.2457 - val_loss: 0.4207 - val_mae: 0.4899 - val_mse: 0.4207
Epoch 53/100
128/128 - 1s - 6ms/step - loss: 0.2435 - mae: 0.3829 - mse: 0.2435 - val_loss: 0.3572 - val_mae: 0.4451 - val_mse: 0.3572
Epoch 54/100
128/128 - 1s - 6ms/step - loss: 0.2451 - mae: 0.3793 - mse: 0.2451 - val_loss: 0.3653 - val_mae: 0.4334 - val_mse: 0.3653
Epoch 55/100
128/128 - 1s - 6ms/step - loss: 0.2447 - mae: 0.3845 - mse: 0.2447 - val_loss: 0.3717 - val_mae: 0.4492 - val_mse: 0.3717
Epoch 56/100
128/128 - 1s - 6ms/step - loss: 0.2415 - mae: 0.3750 - mse: 0.2415 - val_loss: 0.3347 - val_mae: 0.4266 - val_mse: 0.3347
Epoch 57/100
128/128 - 1s - 6ms/step - loss: 0.2397 - mae: 0.3705 - mse: 0.2397 - val_loss: 0.3904 - val_mae: 0.4633 - val_mse: 0.3904
Epoch 58/100
128/128 - 1s - 6ms/step - loss: 0.2397 - mae: 0.3725 - mse: 0.2397 - val_loss: 0.3644 - val_mae: 0.4475 - val_mse: 0.3644
Epoch 59/100
128/128 - 1s - 6ms/step - loss: 0.2337 - mae: 0.3678 - mse: 0.2337 - val_loss: 0.3811 - val_mae: 0.4571 - val_mse: 0.3811
Epoch 60/100
128/128 - 1s - 6ms/step - loss: 0.2345 - mae: 0.3688 - mse: 0.2345 - val_loss: 0.3555 - val_mae: 0.4357 - val_mse: 0.3555
Epoch 61/100
128/128 - 1s - 6ms/step - loss: 0.2340 - mae: 0.3661 - mse: 0.2340 - val_loss: 0.3472 - val_mae: 0.4334 - val_mse: 0.3472
Epoch 62/100
128/128 - 1s - 6ms/step - loss: 0.2252 - mae: 0.3655 - mse: 0.2252 - val_loss: 0.3727 - val_mae: 0.4434 - val_mse: 0.3727
Epoch 63/100
128/128 - 1s - 6ms/step - loss: 0.2231 - mae: 0.3633 - mse: 0.2231 - val_loss: 0.3912 - val_mae: 0.4560 - val_mse: 0.3912
Epoch 64/100
128/128 - 1s - 6ms/step - loss: 0.2166 - mae: 0.3561 - mse: 0.2166 - val_loss: 0.4275 - val_mae: 0.4865 - val_mse: 0.4275
Epoch 65/100
128/128 - 1s - 6ms/step - loss: 0.2221 - mae: 0.3623 - mse: 0.2221 - val_loss: 0.4162 - val_mae: 0.4753 - val_mse: 0.4162
Epoch 66/100
128/128 - 1s - 6ms/step - loss: 0.2175 - mae: 0.3576 - mse: 0.2175 - val_loss: 0.3607 - val_mae: 0.4486 - val_mse: 0.3607
Epoch 67/100
128/128 - 1s - 6ms/step - loss: 0.2155 - mae: 0.3500 - mse: 0.2155 - val_loss: 0.3468 - val_mae: 0.4497 - val_mse: 0.3468
Epoch 68/100
128/128 - 1s - 6ms/step - loss: 0.2125 - mae: 0.3547 - mse: 0.2125 - val_loss: 0.3396 - val_mae: 0.4390 - val_mse: 0.3396
Epoch 69/100
128/128 - 1s - 6ms/step - loss: 0.2081 - mae: 0.3501 - mse: 0.2081 - val_loss: 0.3625 - val_mae: 0.4393 - val_mse: 0.3625
Epoch 70/100
128/128 - 1s - 6ms/step - loss: 0.2127 - mae: 0.3516 - mse: 0.2127 - val_loss: 0.3838 - val_mae: 0.4675 - val_mse: 0.3838
Epoch 71/100
128/128 - 1s - 6ms/step - loss: 0.2100 - mae: 0.3485 - mse: 0.2100 - val_loss: 0.3667 - val_mae: 0.4441 - val_mse: 0.3667
Epoch 72/100
128/128 - 1s - 6ms/step - loss: 0.2116 - mae: 0.3493 - mse: 0.2116 - val_loss: 0.3862 - val_mae: 0.4702 - val_mse: 0.3862
Epoch 73/100
128/128 - 1s - 6ms/step - loss: 0.2043 - mae: 0.3441 - mse: 0.2043 - val_loss: 0.4377 - val_mae: 0.4685 - val_mse: 0.4377
Epoch 74/100
128/128 - 1s - 6ms/step - loss: 0.2030 - mae: 0.3488 - mse: 0.2030 - val_loss: 0.3708 - val_mae: 0.4605 - val_mse: 0.3708
Epoch 75/100
128/128 - 1s - 6ms/step - loss: 0.1985 - mae: 0.3416 - mse: 0.1985 - val_loss: 0.3834 - val_mae: 0.4581 - val_mse: 0.3834
Epoch 76/100
128/128 - 1s - 6ms/step - loss: 0.1971 - mae: 0.3394 - mse: 0.1971 - val_loss: 0.3705 - val_mae: 0.4519 - val_mse: 0.3705
Epoch 77/100
128/128 - 1s - 6ms/step - loss: 0.1939 - mae: 0.3357 - mse: 0.1939 - val_loss: 0.3670 - val_mae: 0.4456 - val_mse: 0.3670
Epoch 78/100
128/128 - 1s - 6ms/step - loss: 0.1943 - mae: 0.3388 - mse: 0.1943 - val_loss: 0.4253 - val_mae: 0.4705 - val_mse: 0.4253
Epoch 79/100
128/128 - 1s - 6ms/step - loss: 0.1937 - mae: 0.3384 - mse: 0.1937 - val_loss: 0.3619 - val_mae: 0.4492 - val_mse: 0.3619
Epoch 80/100
128/128 - 1s - 6ms/step - loss: 0.1907 - mae: 0.3348 - mse: 0.1907 - val_loss: 0.3816 - val_mae: 0.4559 - val_mse: 0.3816
Epoch 81/100
128/128 - 1s - 6ms/step - loss: 0.1915 - mae: 0.3320 - mse: 0.1915 - val_loss: 0.3805 - val_mae: 0.4703 - val_mse: 0.3805
Epoch 82/100
128/128 - 1s - 6ms/step - loss: 0.1831 - mae: 0.3282 - mse: 0.1831 - val_loss: 0.3842 - val_mae: 0.4491 - val_mse: 0.3842
Epoch 83/100
128/128 - 1s - 6ms/step - loss: 0.1890 - mae: 0.3289 - mse: 0.1890 - val_loss: 0.3667 - val_mae: 0.4613 - val_mse: 0.3667
Epoch 84/100
128/128 - 1s - 6ms/step - loss: 0.1906 - mae: 0.3304 - mse: 0.1906 - val_loss: 0.3851 - val_mae: 0.4581 - val_mse: 0.3851
Epoch 85/100
128/128 - 1s - 6ms/step - loss: 0.1837 - mae: 0.3299 - mse: 0.1837 - val_loss: 0.3575 - val_mae: 0.4390 - val_mse: 0.3575
Epoch 86/100
128/128 - 1s - 6ms/step - loss: 0.1823 - mae: 0.3254 - mse: 0.1823 - val_loss: 0.5008 - val_mae: 0.5252 - val_mse: 0.5008
Epoch 87/100
128/128 - 1s - 6ms/step - loss: 0.1819 - mae: 0.3271 - mse: 0.1819 - val_loss: 0.3854 - val_mae: 0.4472 - val_mse: 0.3854
Epoch 88/100
128/128 - 1s - 6ms/step - loss: 0.1794 - mae: 0.3243 - mse: 0.1794 - val_loss: 0.3877 - val_mae: 0.4655 - val_mse: 0.3877
Epoch 89/100
128/128 - 1s - 6ms/step - loss: 0.1784 - mae: 0.3234 - mse: 0.1784 - val_loss: 0.3718 - val_mae: 0.4371 - val_mse: 0.3718
Epoch 90/100
128/128 - 1s - 6ms/step - loss: 0.1781 - mae: 0.3164 - mse: 0.1781 - val_loss: 0.4523 - val_mae: 0.5051 - val_mse: 0.4523
Epoch 91/100
128/128 - 1s - 6ms/step - loss: 0.1727 - mae: 0.3151 - mse: 0.1727 - val_loss: 0.3604 - val_mae: 0.4453 - val_mse: 0.3604
Epoch 92/100
128/128 - 1s - 6ms/step - loss: 0.1778 - mae: 0.3246 - mse: 0.1778 - val_loss: 0.3721 - val_mae: 0.4566 - val_mse: 0.3721
Epoch 93/100
128/128 - 1s - 6ms/step - loss: 0.1728 - mae: 0.3161 - mse: 0.1728 - val_loss: 0.4423 - val_mae: 0.5112 - val_mse: 0.4423
Epoch 94/100
128/128 - 1s - 6ms/step - loss: 0.1706 - mae: 0.3148 - mse: 0.1706 - val_loss: 0.3600 - val_mae: 0.4439 - val_mse: 0.3600
Epoch 95/100
128/128 - 1s - 6ms/step - loss: 0.1730 - mae: 0.3154 - mse: 0.1730 - val_loss: 0.3840 - val_mae: 0.4562 - val_mse: 0.3840
Epoch 96/100
128/128 - 1s - 6ms/step - loss: 0.1685 - mae: 0.3137 - mse: 0.1685 - val_loss: 0.4091 - val_mae: 0.4694 - val_mse: 0.4091
Epoch 97/100
128/128 - 1s - 6ms/step - loss: 0.1640 - mae: 0.3100 - mse: 0.1640 - val_loss: 0.4209 - val_mae: 0.4908 - val_mse: 0.4209
Epoch 98/100
128/128 - 1s - 6ms/step - loss: 0.1687 - mae: 0.3131 - mse: 0.1687 - val_loss: 0.4084 - val_mae: 0.4642 - val_mse: 0.4084
Epoch 99/100
128/128 - 1s - 6ms/step - loss: 0.1638 - mae: 0.3064 - mse: 0.1638 - val_loss: 0.4698 - val_mae: 0.5076 - val_mse: 0.4698
Epoch 100/100
128/128 - 1s - 6ms/step - loss: 0.1634 - mae: 0.3053 - mse: 0.1634 - val_loss: 0.3815 - val_mae: 0.4573 - val_mse: 0.3815
Step 6 - Evaluate¶
6.1 - Model evaluation¶
score = model.evaluate(x_test, y_test, verbose=0)
print('x_test / loss : {:5.4f}'.format(score[0]))
print('x_test / mae : {:5.4f}'.format(score[1]))
print('x_test / mse : {:5.4f}'.format(score[2]))
x_test / loss : 0.3815
x_test / mae  : 0.4573
x_test / mse  : 0.3815
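The MSE is expressed in squared quality points; its square root (RMSE) is easier to interpret, being in the same unit as the quality score — a small sketch :
rmse = np.sqrt(score[2])   # RMSE, in quality points
print('x_test / rmse : {:5.4f}'.format(rmse))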
6.2 - Training history¶
What was the best result during our training ?
print("min( val_mae ) : {:.4f}".format( min(history.history["val_mae"]) ) )
min( val_mae ) : 0.4264
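To locate the epoch that achieved this minimum — a small sketch :
best_epoch = int(np.argmin(history.history['val_mae'])) + 1   # +1 because the Keras log numbers epochs from 1
print(f'Best epoch (lowest val_mae) : {best_epoch}')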
fidle.scrawler.history( history, plot={'MSE' :['mse', 'val_mse'],
'MAE' :['mae', 'val_mae'],
'LOSS':['loss','val_loss']}, save_as='01-history')
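Without the Fidle helper, the same curves can be drawn with plain matplotlib — a minimal sketch for the MAE panel :
import matplotlib.pyplot as plt

plt.figure(figsize=(8,4))
plt.plot(history.history['mae'],     label='mae')
plt.plot(history.history['val_mae'], label='val_mae')
plt.xlabel('Epoch')
plt.ylabel('MAE')
plt.legend()
plt.show()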
Step 7 - Restore a model¶
7.1 - Reload model¶
loaded_model = keras.models.load_model('./run/models/best_model.keras')
loaded_model.summary()
print("Loaded.")
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ Dense_n1 (Dense)                     │ (None, 64)                  │             768 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Dense_n2 (Dense)                     │ (None, 64)                  │           4,160 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Output (Dense)                       │ (None, 1)                   │              65 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 9,988 (39.02 KB)
Trainable params: 4,993 (19.50 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 4,995 (19.51 KB)
Loaded.
7.2 - Evaluate it¶
score = loaded_model.evaluate(x_test, y_test, verbose=0)
print('x_test / loss : {:5.4f}'.format(score[0]))
print('x_test / mae : {:5.4f}'.format(score[1]))
print('x_test / mse : {:5.4f}'.format(score[2]))
x_test / loss : 1.9679
x_test / mae  : 1.0980
x_test / mse  : 1.9679
7.3 - Make a prediction¶
# ---- Pick n entries from our test set
n = 200
ii = np.random.randint(0, len(x_test), n)   # random indices from 0, drawn with replacement
x_sample = x_test[ii]
y_sample = y_test[ii]
# ---- Make predictions
y_pred = loaded_model.predict( x_sample, verbose=2 )
7/7 - 0s - 2ms/step
# ---- Show it
print('Wine Prediction Real Delta')
for i in range(n):
pred = y_pred[i][0]
real = y_sample[i]
delta = real-pred
print(f'{i:03d} {pred:.2f} {real} {delta:+.2f} ')
Wine Prediction Real Delta
000 4.33 5 +0.67
001 4.27 5 +0.73
002 3.99 5 +1.01
003 4.64 6 +1.36
004 5.46 5 -0.46
005 5.64 6 +0.36
006 2.72 5 +2.28
007 4.65 5 +0.35
008 4.06 6 +1.94
009 8.13 5 -3.13
010 6.54 6 -0.54
011 6.59 5 -1.59
012 4.28 6 +1.72
013 4.86 6 +1.14
014 3.91 5 +1.09
015 4.74 6 +1.26
016 5.51 6 +0.49
017 5.56 7 +1.44
018 6.47 6 -0.47
019 9.33 4 -5.33
020 6.70 5 -1.70
021 5.39 5 -0.39
022 5.57 7 +1.43
023 5.49 6 +0.51
024 5.80 6 +0.20
025 5.09 5 -0.09
026 4.55 5 +0.45
027 4.00 6 +2.00
028 4.00 6 +2.00
029 7.23 5 -2.23
030 5.68 7 +1.32
031 5.20 5 -0.20
032 4.59 5 +0.41
033 6.11 4 -2.11
034 4.74 6 +1.26
035 5.24 6 +0.76
036 6.76 5 -1.76
037 5.19 5 -0.19
038 5.53 5 -0.53
039 6.74 7 +0.26
040 5.03 5 -0.03
041 5.64 6 +0.36
042 5.96 5 -0.96
043 4.77 5 +0.23
044 6.60 7 +0.40
045 6.82 6 -0.82
046 4.30 7 +2.70
047 4.33 5 +0.67
048 4.60 5 +0.40
049 5.93 6 +0.07
050 5.39 5 -0.39
051 5.03 5 -0.03
052 4.33 6 +1.67
053 4.00 5 +1.00
054 5.03 5 -0.03
055 5.68 7 +1.32
056 3.79 6 +2.21
057 4.23 7 +2.77
058 5.81 6 +0.19
059 4.78 6 +1.22
060 6.32 5 -1.32
061 3.93 5 +1.07
062 7.54 5 -2.54
063 5.03 6 +0.97
064 3.10 5 +1.90
065 4.41 6 +1.59
066 5.24 6 +0.76
067 7.23 5 -2.23
068 5.50 6 +0.50
069 7.23 7 -0.23
070 6.60 7 +0.40
071 7.80 5 -2.80
072 3.80 6 +2.20
073 5.48 6 +0.52
074 7.15 5 -2.15
075 6.02 6 -0.02
076 5.04 6 +0.96
077 4.27 5 +0.73
078 4.28 5 +0.72
079 5.77 5 -0.77
080 5.49 7 +1.51
081 3.79 6 +2.21
082 5.03 5 -0.03
083 4.93 6 +1.07
084 3.99 5 +1.01
085 4.14 6 +1.86
086 6.19 6 -0.19
087 6.73 6 -0.73
088 4.42 5 +0.58
089 5.24 6 +0.76
090 3.79 5 +1.21
091 5.87 5 -0.87
092 4.83 6 +1.17
093 6.52 7 +0.48
094 4.64 7 +2.36
095 3.53 5 +1.47
096 4.88 5 +0.12
097 3.05 6 +2.95
098 4.00 6 +2.00
099 6.76 6 -0.76
100 6.19 5 -1.19
101 6.59 5 -1.59
102 4.77 6 +1.23
103 5.96 5 -0.96
104 6.76 6 -0.76
105 7.40 6 -1.40
106 4.77 6 +1.23
107 7.54 5 -2.54
108 6.59 5 -1.59
109 4.41 5 +0.59
110 7.40 6 -1.40
111 5.95 4 -1.95
112 6.39 7 +0.61
113 4.60 5 +0.40
114 6.82 6 -0.82
115 3.95 5 +1.05
116 5.52 6 +0.48
117 4.73 4 -0.73
118 4.86 6 +1.14
119 4.27 5 +0.73
120 3.95 5 +1.05
121 5.91 6 +0.09
122 7.23 7 -0.23
123 4.41 6 +1.59
124 5.81 8 +2.19
125 6.19 7 +0.81
126 7.17 5 -2.17
127 4.17 5 +0.83
128 4.93 5 +0.07
129 5.35 8 +2.65
130 4.64 5 +0.36
131 3.74 6 +2.26
132 4.64 5 +0.36
133 6.76 6 -0.76
134 4.27 5 +0.73
135 5.81 6 +0.19
136 4.27 6 +1.73
137 6.94 6 -0.94
138 5.76 6 +0.24
139 7.09 4 -3.09
140 6.39 7 +0.61
141 8.27 7 -1.27
142 4.38 5 +0.62
143 4.42 5 +0.58
144 4.87 5 +0.13
145 4.00 6 +2.00
146 5.09 5 -0.09
147 5.90 7 +1.10
148 4.04 5 +0.96
149 6.02 6 -0.02
150 4.60 5 +0.40
151 6.59 5 -1.59
152 5.68 7 +1.32
153 3.93 5 +1.07
154 5.09 5 -0.09
155 4.35 6 +1.65
156 4.40 5 +0.60
157 8.27 7 -1.27
158 4.33 5 +0.67
159 3.74 6 +2.26
160 6.68 5 -1.68
161 7.02 6 -1.02
162 4.50 7 +2.50
163 6.76 5 -1.76
164 5.62 6 +0.38
165 5.24 6 +0.76
166 4.83 7 +2.17
167 3.30 5 +1.70
168 4.51 6 +1.49
169 8.13 5 -3.13
170 6.19 6 -0.19
171 5.03 6 +0.97
172 4.97 5 +0.03
173 5.56 7 +1.44
174 5.66 6 +0.34
175 5.81 8 +2.19
176 5.35 5 -0.35
177 3.91 5 +1.09
178 4.64 7 +2.36
179 3.52 5 +1.48
180 5.69 6 +0.31
181 4.12 6 +1.88
182 4.29 5 +0.71
183 4.51 6 +1.49
184 6.02 6 -0.02
185 5.10 6 +0.90
186 4.59 6 +1.41
187 3.57 5 +1.43
188 4.41 5 +0.59
189 6.82 6 -0.82
190 5.95 4 -1.95
191 4.99 6 +1.01
192 3.92 5 +1.08
193 5.91 6 +0.09
194 5.68 7 +1.32
195 6.19 5 -1.19
196 5.40 6 +0.60
197 4.13 5 +0.87
198 5.99 5 -0.99
199 5.32 6 +0.68
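To predict the quality of a brand-new wine from its raw analysis, the train normalization must be reapplied first — a sketch with hypothetical values (my_wine is made up for illustration) :
# ---- Hypothetical analysis of one wine, same column order as the dataset (made-up values)
my_wine = np.array([[ 7.4, 0.70, 0.00, 1.9, 0.076, 11.0, 34.0, 0.9978, 3.51, 0.56, 9.4 ]])
# ---- Normalize with the train mean/std computed earlier (mean and std are pandas Series)
my_wine = (my_wine - np.array(mean)) / np.array(std)
# ---- Predict
y_pred = loaded_model.predict(my_wine, verbose=0)
print(f'Predicted quality : {y_pred[0][0]:.2f}')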
A few questions :¶
- Can this model be used for red wines from Bordeaux and/or Beaujolais?
- What are the limitations of this model?
- What are the limitations of this dataset?
fidle.end()
End time : 22/12/24 21:22:25
Duration : 00:01:21 750ms
This notebook ends here :-)
https://fidle.cnrs.fr