[K3BHPD1] - Regression with a Dense Network (DNN)
A simple example of regression on the Boston Housing Prices Dataset (BHPD)
Objectives :
- Predict housing prices from a set of house features.
- Understand the principle and the architecture of a regression with a dense neural network.
The Boston Housing Prices Dataset consists of the prices of houses in various places in Boston.
Alongside the price, the dataset also provides the following information :
- CRIM: This is the per capita crime rate by town
- ZN: This is the proportion of residential land zoned for lots larger than 25,000 sq.ft
- INDUS: This is the proportion of non-retail business acres per town
- CHAS: This is the Charles River dummy variable (this is equal to 1 if tract bounds river; 0 otherwise)
- NOX: This is the nitric oxides concentration (parts per 10 million)
- RM: This is the average number of rooms per dwelling
- AGE: This is the proportion of owner-occupied units built prior to 1940
- DIS: This is the weighted distances to five Boston employment centers
- RAD: This is the index of accessibility to radial highways
- TAX: This is the full-value property-tax rate per 10,000 dollars
- PTRATIO: This is the pupil-teacher ratio by town
- B: This is calculated as 1000(Bk - 0.63)^2, where Bk is the proportion of people of African American descent by town
- LSTAT: This is the percentage lower status of the population
- MEDV: This is the median value of owner-occupied homes in 1000 dollars
What we're going to do :
- Retrieve data
- Prepare the data
- Build a model
- Train the model
- Evaluate the result
Step 1 - Import and init
You can also adjust the verbosity by changing the value of TF_CPP_MIN_LOG_LEVEL :
- 0 = all messages are logged (default)
- 1 = INFO messages are not printed.
- 2 = INFO and WARNING messages are not printed.
- 3 = INFO, WARNING and ERROR messages are not printed.
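A minimal sketch (the value shown is only an example): the variable must be set before TensorFlow is imported, for instance
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # assumption for this sketch: hide INFO and WARNING messages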
In [1]:
import os
os.environ['KERAS_BACKEND'] = 'torch'
import keras
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sys
import fidle
# Init Fidle environment
run_id, run_dir, datasets_dir = fidle.init('K3BHPD1')
FIDLE - Environment initialization
Version              : 2.3.0
Run id               : K3BHPD1
Run dir              : ./run/K3BHPD1
Datasets dir         : /gpfswork/rech/mlh/uja62cb/fidle-project/datasets-fidle
Start time           : 03/03/24 21:03:34
Hostname             : r3i6n3 (Linux)
Tensorflow log level : Warning + Error (=1)
Update keras cache   : False
Update torch cache   : False
Save figs            : ./run/K3BHPD1/figs (True)
keras                : 3.0.4
numpy                : 1.24.4
sklearn              : 1.3.2
yaml                 : 6.0.1
matplotlib           : 3.8.2
pandas               : 2.1.3
torch                : 2.1.1
Verbosity during training :
- 0 = silent
- 1 = progress bar
- 2 = one line per epoch
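This value is simply forwarded to fit(); a minimal sketch, assuming a compiled model and training arrays x and y:
model.fit(x, y, epochs=10, verbose=2)   # verbose=2 gives one line per epoch, convenient for batch logs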
In [2]:
fit_verbosity = 1
Override parameters (batch mode) - Just forget this cell
In [3]:
fidle.override('fit_verbosity')
** Overrided parameters : ** fit_verbosity : 2
Step 2 - Retrieve data
2.1 - Option 1 : From Keras
Boston Housing is a famous historical dataset, so we can get it directly from the Keras datasets module.
In [4]:
# (x_train, y_train), (x_test, y_test) = keras.datasets.boston_housing.load_data(test_split=0.2, seed=113)
2.2 - Option 2 : From a CSV file
More fun!
In [5]:
data = pd.read_csv(f'{datasets_dir}/BHPD/origine/BostonHousing.csv', header=0)
display(data.head(5).style.format("{0:.2f}").set_caption("Few lines of the dataset :"))
print('Missing Data : ',data.isna().sum().sum(), ' Shape is : ', data.shape)
| | crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat | medv |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.01 | 18.00 | 2.31 | 0.00 | 0.54 | 6.58 | 65.20 | 4.09 | 1.00 | 296.00 | 15.30 | 396.90 | 4.98 | 24.00 |
1 | 0.03 | 0.00 | 7.07 | 0.00 | 0.47 | 6.42 | 78.90 | 4.97 | 2.00 | 242.00 | 17.80 | 396.90 | 9.14 | 21.60 |
2 | 0.03 | 0.00 | 7.07 | 0.00 | 0.47 | 7.18 | 61.10 | 4.97 | 2.00 | 242.00 | 17.80 | 392.83 | 4.03 | 34.70 |
3 | 0.03 | 0.00 | 2.18 | 0.00 | 0.46 | 7.00 | 45.80 | 6.06 | 3.00 | 222.00 | 18.70 | 394.63 | 2.94 | 33.40 |
4 | 0.07 | 0.00 | 2.18 | 0.00 | 0.46 | 7.15 | 54.20 | 6.06 | 3.00 | 222.00 | 18.70 | 396.90 | 5.33 | 36.20 |
Missing Data : 0 Shape is : (506, 14)
In [6]:
# ---- Shuffle and Split => train, test
#
data = data.sample(frac=1., axis=0)
data_train = data.sample(frac=0.7, axis=0)
data_test = data.drop(data_train.index)
# ---- Split => x,y (medv is price)
#
x_train = data_train.drop('medv', axis=1)
y_train = data_train['medv']
x_test = data_test.drop('medv', axis=1)
y_test = data_test['medv']
print('Original data shape was : ',data.shape)
print('x_train : ',x_train.shape, 'y_train : ',y_train.shape)
print('x_test : ',x_test.shape, 'y_test : ',y_test.shape)
Original data shape was :  (506, 14)
x_train :  (354, 13) y_train :  (354,)
x_test  :  (152, 13) y_test  :  (152,)
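Equivalently, and only as an illustrative sketch (the notebook keeps the pandas version above), the shuffle and split could be done with scikit-learn:
from sklearn.model_selection import train_test_split
# 70% train / 30% test, shuffled - equivalent to the sample()/drop() calls above
x_train, x_test, y_train, y_test = train_test_split(data.drop('medv', axis=1), data['medv'], train_size=0.7, shuffle=True)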
3.2 - Data normalization
Note :
- All input data must be normalized, train and test.
- To do this we will subtract the mean and divide by the standard deviation.
- But test data should not be used in any way, even for normalization.
- The mean and the standard deviation will therefore only be calculated with the train data.
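As an equivalent sketch (for illustration only, the notebook uses the pandas version below), the same train-only standardization could be written with scikit-learn's StandardScaler:
from sklearn.preprocessing import StandardScaler
scaler      = StandardScaler().fit(x_train)   # mean and std computed on the train set only
x_train_std = scaler.transform(x_train)
x_test_std  = scaler.transform(x_test)        # the test set reuses the train statistics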
In [7]:
display(x_train.describe().style.format("{0:.2f}").set_caption("Before normalization :"))
mean = x_train.mean()
std = x_train.std()
x_train = (x_train - mean) / std
x_test = (x_test - mean) / std
display(x_train.describe().style.format("{0:.2f}").set_caption("After normalization :"))
display(x_train.head(5).style.format("{0:.2f}").set_caption("Few lines of the dataset :"))
x_train, y_train = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)
| | crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 |
mean | 3.38 | 10.66 | 11.18 | 0.08 | 0.56 | 6.26 | 68.36 | 3.79 | 9.61 | 408.45 | 18.53 | 351.78 | 12.71 |
std | 7.34 | 23.04 | 6.87 | 0.27 | 0.12 | 0.73 | 28.09 | 2.12 | 8.72 | 169.72 | 2.12 | 96.20 | 7.08 |
min | 0.01 | 0.00 | 1.21 | 0.00 | 0.39 | 3.56 | 2.90 | 1.13 | 1.00 | 187.00 | 12.60 | 2.52 | 2.47 |
25% | 0.08 | 0.00 | 5.19 | 0.00 | 0.45 | 5.88 | 44.58 | 2.10 | 4.00 | 279.25 | 17.40 | 371.62 | 7.04 |
50% | 0.28 | 0.00 | 9.69 | 0.00 | 0.54 | 6.17 | 76.85 | 3.17 | 5.00 | 330.00 | 19.10 | 390.47 | 11.46 |
75% | 3.68 | 0.00 | 18.10 | 0.00 | 0.62 | 6.60 | 93.88 | 5.11 | 24.00 | 666.00 | 20.20 | 396.04 | 16.95 |
max | 73.53 | 100.00 | 27.74 | 1.00 | 0.87 | 8.78 | 100.00 | 12.13 | 24.00 | 711.00 | 22.00 | 396.90 | 37.97 |
| | crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 | 354.00 |
mean | -0.00 | 0.00 | 0.00 | -0.00 | 0.00 | -0.00 | -0.00 | -0.00 | -0.00 | -0.00 | 0.00 | -0.00 | 0.00 |
std | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
min | -0.46 | -0.46 | -1.45 | -0.29 | -1.45 | -3.72 | -2.33 | -1.25 | -0.99 | -1.30 | -2.80 | -3.63 | -1.45 |
25% | -0.45 | -0.46 | -0.87 | -0.29 | -0.88 | -0.53 | -0.85 | -0.79 | -0.64 | -0.76 | -0.53 | 0.21 | -0.80 |
50% | -0.42 | -0.46 | -0.22 | -0.29 | -0.16 | -0.13 | 0.30 | -0.29 | -0.53 | -0.46 | 0.27 | 0.40 | -0.18 |
75% | 0.04 | -0.46 | 1.01 | -0.29 | 0.56 | 0.47 | 0.91 | 0.63 | 1.65 | 1.52 | 0.79 | 0.46 | 0.60 |
max | 9.56 | 3.88 | 2.41 | 3.48 | 2.64 | 3.48 | 1.13 | 3.93 | 1.65 | 1.78 | 1.64 | 0.47 | 3.57 |
| | crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
407 | 1.17 | -0.46 | 1.01 | -0.29 | 0.85 | -0.90 | 1.13 | -1.18 | 1.65 | 1.52 | 0.79 | -0.20 | -0.08 |
315 | -0.43 | -0.46 | -0.19 | -0.29 | -0.11 | -0.76 | 0.33 | 0.07 | -0.64 | -0.62 | -0.06 | 0.46 | -0.17 |
138 | -0.43 | -0.46 | 1.56 | -0.29 | 0.56 | -0.55 | 1.06 | -1.00 | -0.64 | 0.17 | 1.26 | 0.42 | 1.22 |
307 | -0.45 | 0.97 | -1.31 | -0.29 | -0.72 | 0.81 | 0.07 | -0.29 | -0.30 | -1.10 | -0.06 | 0.47 | -0.73 |
78 | -0.45 | -0.46 | 0.24 | -0.29 | -1.01 | -0.04 | -0.52 | 0.58 | -0.53 | -0.06 | 0.08 | 0.36 | -0.05 |
In [8]:
def get_model_v1(shape):
    model = keras.models.Sequential()
    model.add(keras.layers.Input(shape, name="InputLayer"))
    model.add(keras.layers.Dense(32, activation='relu', name='Dense_n1'))
    model.add(keras.layers.Dense(64, activation='relu', name='Dense_n2'))
    model.add(keras.layers.Dense(32, activation='relu', name='Dense_n3'))
    model.add(keras.layers.Dense(1, name='Output'))
    model.compile(optimizer = 'adam',
                  loss      = 'mse',
                  metrics   = ['mae', 'mse'] )
    return model
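As an aside, the same architecture can also be written with the Keras functional API (an equivalent sketch, not used in the rest of the notebook); note that the output layer has no activation because we predict a continuous price:
def get_model_v1_functional(shape):
    inputs  = keras.layers.Input(shape=shape, name='InputLayer')
    x       = keras.layers.Dense(32, activation='relu', name='Dense_n1')(inputs)
    x       = keras.layers.Dense(64, activation='relu', name='Dense_n2')(x)
    x       = keras.layers.Dense(32, activation='relu', name='Dense_n3')(x)
    outputs = keras.layers.Dense(1, name='Output')(x)
    model   = keras.Model(inputs, outputs)
    model.compile(optimizer='adam', loss='mse', metrics=['mae', 'mse'])
    return model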
In [9]:
model=get_model_v1( (13,) )
model.summary()
# img=keras.utils.plot_model( model, to_file='./run/model.png', show_shapes=True, show_layer_names=True, dpi=96)
# display(img)
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ Dense_n1 (Dense)                │ (None, 32)                │        448 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ Dense_n2 (Dense)                │ (None, 64)                │      2,112 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ Dense_n3 (Dense)                │ (None, 32)                │      2,080 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ Output (Dense)                  │ (None, 1)                 │         33 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 4,673 (18.25 KB)
Trainable params: 4,673 (18.25 KB)
Non-trainable params: 0 (0.00 B)
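For reference, each Dense layer has (inputs + 1) × units parameters: (13+1)×32 = 448, (32+1)×64 = 2,112, (64+1)×32 = 2,080 and (32+1)×1 = 33, i.e. 4,673 parameters in total, as reported above.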
5.2 - Train it
In [10]:
history = model.fit(x_train,
                    y_train,
                    epochs          = 60,
                    batch_size      = 10,
                    verbose         = fit_verbosity,
                    validation_data = (x_test, y_test))
Epoch 1/60 36/36 - 1s - 20ms/step - loss: 543.1082 - mae: 21.3474 - mse: 543.1082 - val_loss: 530.8257 - val_mae: 20.4749 - val_mse: 530.8257
Epoch 2/60 36/36 - 0s - 7ms/step - loss: 333.9670 - mae: 15.8058 - mse: 333.9670 - val_loss: 187.0211 - val_mae: 10.6376 - val_mse: 187.0211
Epoch 3/60 36/36 - 0s - 7ms/step - loss: 67.2444 - mae: 6.2388 - mse: 67.2444 - val_loss: 83.0897 - val_mae: 5.4011 - val_mse: 83.0897
Epoch 4/60 36/36 - 0s - 7ms/step - loss: 29.2588 - mae: 3.8964 - mse: 29.2588 - val_loss: 64.2306 - val_mae: 4.6466 - val_mse: 64.2306
Epoch 5/60 36/36 - 0s - 7ms/step - loss: 22.1348 - mae: 3.2935 - mse: 22.1348 - val_loss: 55.9101 - val_mae: 4.2648 - val_mse: 55.9101
Epoch 6/60 36/36 - 0s - 7ms/step - loss: 19.2369 - mae: 3.0147 - mse: 19.2369 - val_loss: 50.4202 - val_mae: 4.1650 - val_mse: 50.4202
Epoch 7/60 36/36 - 0s - 7ms/step - loss: 17.4199 - mae: 2.8746 - mse: 17.4199 - val_loss: 45.7900 - val_mae: 4.1948 - val_mse: 45.7900
Epoch 8/60 36/36 - 0s - 7ms/step - loss: 16.4000 - mae: 2.8069 - mse: 16.4000 - val_loss: 42.2878 - val_mae: 3.8906 - val_mse: 42.2878
Epoch 9/60 36/36 - 0s - 7ms/step - loss: 14.7042 - mae: 2.6721 - mse: 14.7042 - val_loss: 39.6995 - val_mae: 3.9133 - val_mse: 39.6995
Epoch 10/60 36/36 - 0s - 7ms/step - loss: 14.0582 - mae: 2.6152 - mse: 14.0582 - val_loss: 37.5213 - val_mae: 3.5730 - val_mse: 37.5213
Epoch 11/60 36/36 - 0s - 7ms/step - loss: 13.5775 - mae: 2.5909 - mse: 13.5775 - val_loss: 34.2684 - val_mae: 3.5854 - val_mse: 34.2684
Epoch 12/60 36/36 - 0s - 7ms/step - loss: 12.9532 - mae: 2.4971 - mse: 12.9532 - val_loss: 32.6236 - val_mae: 3.5739 - val_mse: 32.6236
Epoch 13/60 36/36 - 0s - 7ms/step - loss: 11.9487 - mae: 2.4161 - mse: 11.9487 - val_loss: 33.5929 - val_mae: 3.5055 - val_mse: 33.5929
Epoch 14/60 36/36 - 0s - 7ms/step - loss: 11.7234 - mae: 2.4240 - mse: 11.7234 - val_loss: 32.2932 - val_mae: 3.4785 - val_mse: 32.2932
Epoch 15/60 36/36 - 0s - 7ms/step - loss: 11.1487 - mae: 2.3635 - mse: 11.1487 - val_loss: 30.6134 - val_mae: 3.4204 - val_mse: 30.6134
Epoch 16/60 36/36 - 0s - 7ms/step - loss: 10.8006 - mae: 2.3213 - mse: 10.8006 - val_loss: 29.9556 - val_mae: 3.3401 - val_mse: 29.9556
Epoch 17/60 36/36 - 0s - 7ms/step - loss: 10.4818 - mae: 2.3065 - mse: 10.4818 - val_loss: 29.4162 - val_mae: 3.3251 - val_mse: 29.4162
Epoch 18/60 36/36 - 0s - 7ms/step - loss: 10.2679 - mae: 2.2538 - mse: 10.2679 - val_loss: 28.6373 - val_mae: 3.4103 - val_mse: 28.6373
Epoch 19/60 36/36 - 0s - 7ms/step - loss: 9.9521 - mae: 2.2570 - mse: 9.9521 - val_loss: 28.6532 - val_mae: 3.2360 - val_mse: 28.6532
Epoch 20/60 36/36 - 0s - 7ms/step - loss: 10.3724 - mae: 2.2970 - mse: 10.3724 - val_loss: 28.9017 - val_mae: 3.1968 - val_mse: 28.9017
Epoch 21/60 36/36 - 0s - 7ms/step - loss: 9.6463 - mae: 2.2016 - mse: 9.6463 - val_loss: 28.5355 - val_mae: 3.0586 - val_mse: 28.5355
Epoch 22/60 36/36 - 0s - 7ms/step - loss: 9.5152 - mae: 2.2188 - mse: 9.5152 - val_loss: 27.2590 - val_mae: 3.3128 - val_mse: 27.2590
Epoch 23/60 36/36 - 0s - 7ms/step - loss: 9.2451 - mae: 2.1551 - mse: 9.2451 - val_loss: 26.6165 - val_mae: 3.3330 - val_mse: 26.6165
Epoch 24/60 36/36 - 0s - 7ms/step - loss: 8.9698 - mae: 2.1389 - mse: 8.9698 - val_loss: 26.4686 - val_mae: 3.1454 - val_mse: 26.4686
Epoch 25/60 36/36 - 0s - 7ms/step - loss: 8.8697 - mae: 2.0963 - mse: 8.8697 - val_loss: 26.2414 - val_mae: 3.2706 - val_mse: 26.2414
Epoch 26/60 36/36 - 0s - 7ms/step - loss: 9.1039 - mae: 2.1769 - mse: 9.1039 - val_loss: 27.0105 - val_mae: 3.2398 - val_mse: 27.0105
Epoch 27/60 36/36 - 0s - 7ms/step - loss: 8.9789 - mae: 2.1527 - mse: 8.9789 - val_loss: 25.8938 - val_mae: 3.3328 - val_mse: 25.8938
Epoch 28/60 36/36 - 0s - 7ms/step - loss: 8.1949 - mae: 2.0133 - mse: 8.1949 - val_loss: 25.1584 - val_mae: 3.2859 - val_mse: 25.1584
Epoch 29/60 36/36 - 0s - 7ms/step - loss: 8.2697 - mae: 2.0489 - mse: 8.2697 - val_loss: 26.2097 - val_mae: 3.2547 - val_mse: 26.2097
Epoch 30/60 36/36 - 0s - 7ms/step - loss: 8.0759 - mae: 2.0567 - mse: 8.0759 - val_loss: 24.3104 - val_mae: 3.2883 - val_mse: 24.3104
Epoch 31/60 36/36 - 0s - 7ms/step - loss: 7.8675 - mae: 1.9980 - mse: 7.8675 - val_loss: 24.1865 - val_mae: 3.0332 - val_mse: 24.1865
Epoch 32/60 36/36 - 0s - 7ms/step - loss: 7.7421 - mae: 1.9648 - mse: 7.7421 - val_loss: 23.0636 - val_mae: 3.2237 - val_mse: 23.0636
Epoch 33/60 36/36 - 0s - 7ms/step - loss: 7.9463 - mae: 2.0408 - mse: 7.9463 - val_loss: 25.7716 - val_mae: 3.0403 - val_mse: 25.7716
Epoch 34/60 36/36 - 0s - 7ms/step - loss: 7.7665 - mae: 2.0182 - mse: 7.7665 - val_loss: 24.4834 - val_mae: 2.8762 - val_mse: 24.4834
Epoch 35/60 36/36 - 0s - 7ms/step - loss: 7.4877 - mae: 1.9551 - mse: 7.4877 - val_loss: 23.0045 - val_mae: 2.9572 - val_mse: 23.0045
Epoch 36/60 36/36 - 0s - 7ms/step - loss: 7.1787 - mae: 1.9141 - mse: 7.1787 - val_loss: 22.0238 - val_mae: 2.9616 - val_mse: 22.0238
Epoch 37/60 36/36 - 0s - 7ms/step - loss: 6.9573 - mae: 1.8987 - mse: 6.9573 - val_loss: 22.9940 - val_mae: 2.8356 - val_mse: 22.9940
Epoch 38/60 36/36 - 0s - 7ms/step - loss: 6.9769 - mae: 1.8959 - mse: 6.9769 - val_loss: 23.8300 - val_mae: 2.9075 - val_mse: 23.8300
Epoch 39/60 36/36 - 0s - 7ms/step - loss: 6.9622 - mae: 1.8937 - mse: 6.9622 - val_loss: 22.7643 - val_mae: 3.0267 - val_mse: 22.7643
Epoch 40/60 36/36 - 0s - 7ms/step - loss: 6.7445 - mae: 1.8647 - mse: 6.7445 - val_loss: 22.3702 - val_mae: 2.9571 - val_mse: 22.3702
Epoch 41/60 36/36 - 0s - 7ms/step - loss: 6.8661 - mae: 1.8730 - mse: 6.8661 - val_loss: 22.0981 - val_mae: 2.7624 - val_mse: 22.0981
Epoch 42/60 36/36 - 0s - 7ms/step - loss: 7.0245 - mae: 1.9152 - mse: 7.0245 - val_loss: 20.1928 - val_mae: 2.9106 - val_mse: 20.1928
Epoch 43/60 36/36 - 0s - 7ms/step - loss: 6.4436 - mae: 1.7931 - mse: 6.4436 - val_loss: 19.8561 - val_mae: 2.9603 - val_mse: 19.8561
Epoch 44/60 36/36 - 0s - 7ms/step - loss: 6.5864 - mae: 1.8390 - mse: 6.5864 - val_loss: 19.8864 - val_mae: 3.0288 - val_mse: 19.8864
Epoch 45/60 36/36 - 0s - 7ms/step - loss: 6.1864 - mae: 1.8175 - mse: 6.1864 - val_loss: 18.9116 - val_mae: 2.7530 - val_mse: 18.9116
Epoch 46/60 36/36 - 0s - 7ms/step - loss: 6.6438 - mae: 1.8382 - mse: 6.6438 - val_loss: 19.5592 - val_mae: 2.9954 - val_mse: 19.5592
Epoch 47/60 36/36 - 0s - 7ms/step - loss: 6.1429 - mae: 1.7933 - mse: 6.1429 - val_loss: 18.7630 - val_mae: 2.6702 - val_mse: 18.7630
Epoch 48/60 36/36 - 0s - 7ms/step - loss: 5.9220 - mae: 1.7629 - mse: 5.9220 - val_loss: 19.4392 - val_mae: 2.7110 - val_mse: 19.4392
Epoch 49/60 36/36 - 0s - 7ms/step - loss: 5.7460 - mae: 1.7067 - mse: 5.7460 - val_loss: 18.0015 - val_mae: 2.8765 - val_mse: 18.0015
Epoch 50/60 36/36 - 0s - 7ms/step - loss: 5.8500 - mae: 1.7713 - mse: 5.8500 - val_loss: 17.7976 - val_mae: 2.7511 - val_mse: 17.7976
Epoch 51/60 36/36 - 0s - 7ms/step - loss: 5.6177 - mae: 1.7191 - mse: 5.6177 - val_loss: 17.0336 - val_mae: 2.6736 - val_mse: 17.0336
Epoch 52/60 36/36 - 0s - 7ms/step - loss: 5.4596 - mae: 1.6956 - mse: 5.4596 - val_loss: 17.9148 - val_mae: 2.7196 - val_mse: 17.9148
Epoch 53/60 36/36 - 0s - 7ms/step - loss: 5.5194 - mae: 1.6889 - mse: 5.5194 - val_loss: 17.1585 - val_mae: 2.7743 - val_mse: 17.1585
Epoch 54/60 36/36 - 0s - 7ms/step - loss: 5.3435 - mae: 1.7139 - mse: 5.3435 - val_loss: 17.4382 - val_mae: 2.7031 - val_mse: 17.4382
Epoch 55/60 36/36 - 0s - 7ms/step - loss: 5.2430 - mae: 1.6497 - mse: 5.2430 - val_loss: 18.8947 - val_mae: 2.8949 - val_mse: 18.8947
Epoch 56/60 36/36 - 0s - 7ms/step - loss: 5.2768 - mae: 1.7162 - mse: 5.2768 - val_loss: 17.7002 - val_mae: 3.1200 - val_mse: 17.7002
Epoch 57/60 36/36 - 0s - 7ms/step - loss: 5.4849 - mae: 1.7319 - mse: 5.4849 - val_loss: 16.6518 - val_mae: 2.5902 - val_mse: 16.6518
Epoch 58/60 36/36 - 0s - 7ms/step - loss: 5.0016 - mae: 1.6098 - mse: 5.0016 - val_loss: 16.5865 - val_mae: 2.8223 - val_mse: 16.5865
Epoch 59/60 36/36 - 0s - 7ms/step - loss: 4.7166 - mae: 1.5636 - mse: 4.7166 - val_loss: 17.2100 - val_mae: 3.0379 - val_mse: 17.2100
Epoch 60/60 36/36 - 0s - 7ms/step - loss: 5.1727 - mae: 1.6823 - mse: 5.1727 - val_loss: 15.2430 - val_mae: 2.6203 - val_mse: 15.2430
In [11]:
score = model.evaluate(x_test, y_test, verbose=0)
print('x_test / loss : {:5.4f}'.format(score[0]))
print('x_test / mae : {:5.4f}'.format(score[1]))
print('x_test / mse : {:5.4f}'.format(score[2]))
x_test / loss : 13.1545
x_test / mae : 2.4604
x_test / mse : 13.1545
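The order of the values in score follows the metrics declared in compile(); as a sketch (not part of the original cell), evaluate() can also return a dict, which avoids relying on indices:
scores = model.evaluate(x_test, y_test, verbose=0, return_dict=True)
print('x_test / mae : {:5.4f}'.format(scores['mae']))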
6.2 - Training history
What was the best result during our training?
In [12]:
df=pd.DataFrame(data=history.history)
display(df)
| | loss | mae | mse | val_loss | val_mae | val_mse |
|---|---|---|---|---|---|---|
0 | 543.108215 | 21.347355 | 543.108215 | 530.825745 | 20.474915 | 530.825745 |
1 | 333.966980 | 15.805815 | 333.966980 | 187.021072 | 10.637609 | 187.021072 |
2 | 67.244400 | 6.238824 | 67.244400 | 83.089676 | 5.401127 | 83.089676 |
3 | 29.258842 | 3.896393 | 29.258842 | 64.230644 | 4.646584 | 64.230644 |
4 | 22.134773 | 3.293489 | 22.134773 | 55.910118 | 4.264829 | 55.910118 |
5 | 19.236858 | 3.014724 | 19.236858 | 50.420227 | 4.164992 | 50.420227 |
6 | 17.419882 | 2.874641 | 17.419882 | 45.790028 | 4.194780 | 45.790028 |
7 | 16.400032 | 2.806903 | 16.400032 | 42.287830 | 3.890553 | 42.287830 |
8 | 14.704198 | 2.672128 | 14.704198 | 39.699509 | 3.913334 | 39.699509 |
9 | 14.058215 | 2.615244 | 14.058215 | 37.521290 | 3.572979 | 37.521290 |
10 | 13.577485 | 2.590885 | 13.577485 | 34.268379 | 3.585387 | 34.268379 |
11 | 12.953210 | 2.497066 | 12.953210 | 32.623631 | 3.573913 | 32.623631 |
12 | 11.948710 | 2.416056 | 11.948710 | 33.592857 | 3.505531 | 33.592857 |
13 | 11.723386 | 2.424000 | 11.723386 | 32.293243 | 3.478509 | 32.293243 |
14 | 11.148730 | 2.363535 | 11.148730 | 30.613375 | 3.420393 | 30.613375 |
15 | 10.800570 | 2.321263 | 10.800570 | 29.955589 | 3.340061 | 29.955589 |
16 | 10.481842 | 2.306483 | 10.481842 | 29.416155 | 3.325119 | 29.416155 |
17 | 10.267911 | 2.253756 | 10.267911 | 28.637268 | 3.410266 | 28.637268 |
18 | 9.952110 | 2.257045 | 9.952110 | 28.653204 | 3.236047 | 28.653204 |
19 | 10.372357 | 2.296954 | 10.372357 | 28.901722 | 3.196816 | 28.901722 |
20 | 9.646282 | 2.201633 | 9.646282 | 28.535492 | 3.058609 | 28.535492 |
21 | 9.515194 | 2.218797 | 9.515194 | 27.258991 | 3.312815 | 27.258991 |
22 | 9.245111 | 2.155149 | 9.245111 | 26.616455 | 3.333007 | 26.616455 |
23 | 8.969801 | 2.138906 | 8.969801 | 26.468590 | 3.145396 | 26.468590 |
24 | 8.869717 | 2.096283 | 8.869717 | 26.241409 | 3.270614 | 26.241409 |
25 | 9.103916 | 2.176872 | 9.103916 | 27.010534 | 3.239777 | 27.010534 |
26 | 8.978923 | 2.152682 | 8.978923 | 25.893824 | 3.332818 | 25.893824 |
27 | 8.194920 | 2.013315 | 8.194920 | 25.158413 | 3.285944 | 25.158413 |
28 | 8.269663 | 2.048867 | 8.269663 | 26.209673 | 3.254708 | 26.209673 |
29 | 8.075867 | 2.056707 | 8.075867 | 24.310368 | 3.288310 | 24.310368 |
30 | 7.867470 | 1.997968 | 7.867470 | 24.186493 | 3.033187 | 24.186493 |
31 | 7.742054 | 1.964831 | 7.742054 | 23.063612 | 3.223681 | 23.063612 |
32 | 7.946270 | 2.040770 | 7.946270 | 25.771570 | 3.040349 | 25.771570 |
33 | 7.766536 | 2.018191 | 7.766536 | 24.483360 | 2.876218 | 24.483360 |
34 | 7.487707 | 1.955132 | 7.487707 | 23.004549 | 2.957219 | 23.004549 |
35 | 7.178676 | 1.914056 | 7.178676 | 22.023815 | 2.961606 | 22.023815 |
36 | 6.957317 | 1.898735 | 6.957317 | 22.993988 | 2.835628 | 22.993988 |
37 | 6.976868 | 1.895854 | 6.976868 | 23.829964 | 2.907541 | 23.829964 |
38 | 6.962208 | 1.893683 | 6.962208 | 22.764338 | 3.026685 | 22.764338 |
39 | 6.744496 | 1.864698 | 6.744496 | 22.370192 | 2.957121 | 22.370192 |
40 | 6.866076 | 1.872955 | 6.866076 | 22.098078 | 2.762429 | 22.098078 |
41 | 7.024477 | 1.915174 | 7.024477 | 20.192778 | 2.910629 | 20.192778 |
42 | 6.443647 | 1.793054 | 6.443647 | 19.856092 | 2.960336 | 19.856092 |
43 | 6.586435 | 1.839020 | 6.586435 | 19.886429 | 3.028769 | 19.886429 |
44 | 6.186442 | 1.817471 | 6.186442 | 18.911575 | 2.753022 | 18.911575 |
45 | 6.643803 | 1.838165 | 6.643803 | 19.559246 | 2.995370 | 19.559246 |
46 | 6.142876 | 1.793305 | 6.142876 | 18.763016 | 2.670213 | 18.763016 |
47 | 5.922019 | 1.762872 | 5.922019 | 19.439199 | 2.710960 | 19.439199 |
48 | 5.746012 | 1.706701 | 5.746012 | 18.001497 | 2.876452 | 18.001497 |
49 | 5.849993 | 1.771330 | 5.849993 | 17.797596 | 2.751129 | 17.797596 |
50 | 5.617665 | 1.719082 | 5.617665 | 17.033596 | 2.673640 | 17.033596 |
51 | 5.459576 | 1.695600 | 5.459576 | 17.914778 | 2.719574 | 17.914778 |
52 | 5.519361 | 1.688948 | 5.519361 | 17.158506 | 2.774285 | 17.158506 |
53 | 5.343516 | 1.713885 | 5.343516 | 17.438204 | 2.703115 | 17.438204 |
54 | 5.242969 | 1.649727 | 5.242969 | 18.894669 | 2.894852 | 18.894669 |
55 | 5.276813 | 1.716197 | 5.276813 | 17.700230 | 3.120050 | 17.700230 |
56 | 5.484945 | 1.731909 | 5.484945 | 16.651814 | 2.590214 | 16.651814 |
57 | 5.001614 | 1.609809 | 5.001614 | 16.586475 | 2.822302 | 16.586475 |
58 | 4.716555 | 1.563606 | 4.716555 | 17.209959 | 3.037879 | 17.209959 |
59 | 5.172710 | 1.682285 | 5.172710 | 15.243008 | 2.620307 | 15.243008 |
In [13]:
print("min( val_mae ) : {:.4f}".format( min(history.history["val_mae"]) ) )
min( val_mae ) : 2.5902
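To also know at which epoch this minimum was reached, a small sketch using the same history object:
best_epoch = int(np.argmin(history.history['val_mae']))   # 0-based index of the best epoch
print('Best val_mae : {:.4f} (epoch {})'.format(history.history['val_mae'][best_epoch], best_epoch + 1))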
In [14]:
fidle.scrawler.history( history, plot={'MSE' :['mse', 'val_mse'],
                                       'MAE' :['mae', 'val_mae'],
                                       'LOSS':['loss','val_loss']}, save_as='01-history')
Saved: ./run/K3BHPD1/figs/01-history_0
Saved: ./run/K3BHPD1/figs/01-history_1
Saved: ./run/K3BHPD1/figs/01-history_2
Step 7 - Make a prediction
The input data must be normalized with the same parameters (mean, std) computed earlier on the training set.
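For example, starting from a hypothetical un-normalized observation raw_sample (illustrative values, given in the same column order as the training data), the scaling would reuse the mean and std computed above:
raw_sample = np.array([[0.03, 0.0, 7.07, 0.0, 0.47, 6.42, 78.9, 4.97, 2.0, 242.0, 17.8, 396.9, 9.14]])
sample_n   = (raw_sample - mean.values) / std.values   # training statistics, never the test ones
print(model.predict(sample_n, verbose=0)[0][0])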
In [15]:
my_data = [ 1.26425925, -0.48522739, 1.0436489 , -0.23112788, 1.37120745,
-2.14308942, 1.13489104, -1.06802005, 1.71189006, 1.57042287,
0.77859951, 0.14769795, 2.7585581 ]
real_price = 10.4
my_data=np.array(my_data).reshape(1,13)
In [16]:
predictions = model.predict( my_data, verbose=fit_verbosity )
print("Prediction : {:.2f} K$".format(predictions[0][0]))
print("Reality : {:.2f} K$".format(real_price))
1/1 - 0s - 3ms/step
Prediction : 10.92 K$
Reality : 10.40 K$
In [17]:
fidle.end()
End time : 03/03/24 21:03:53
Duration : 00:00:20 646ms
This notebook ends here :-)
https://fidle.cnrs.fr