{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training with missing data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us load a data set and have a look at it." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " feature_0 feature_1 feature_2 feature_3 target_0 target_1 target_2 \\\n", "0 0.472986 NaN 0.242439 -1.700736 6.465163 1.247767 1.562335 \n", "1 0.753143 -1.534721 NaN -0.120228 NaN 1.461133 1.682604 \n", "2 -0.806982 2.871819 NaN 0.472457 -2.545614 NaN -3.310312 \n", "3 NaN NaN 1.342356 -0.122150 NaN 1.049569 0.504198 \n", "4 1.012515 -0.913869 -1.029530 1.209796 NaN NaN 1.132657 \n", ".. ... ... ... ... ... ... ... \n", "95 0.078516 -0.837245 1.094795 NaN 3.867749 1.255217 0.865133 \n", "96 0.959965 -1.167800 -0.334090 0.827424 0.544013 2.263673 NaN \n", "97 0.865017 -0.855405 0.071817 -1.125955 5.417294 1.349000 1.600092 \n", "98 -0.206309 0.421580 NaN 1.481052 -3.566368 1.444973 -0.434093 \n", "99 0.495926 NaN -0.565377 -0.131805 1.555337 1.582580 0.622529 \n", "\n", " target_3 \n", "0 0.563148 \n", "1 -1.078739 \n", "2 0.749930 \n", "3 NaN \n", "4 -0.651620 \n", ".. ... \n", "95 NaN \n", "96 -0.057245 \n", "97 0.322496 \n", "98 -1.330253 \n", "99 0.949257 \n", "\n", "[100 rows x 8 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "x_train = pd.read_csv(\"training_data_input.csv\")\n", "y_train = pd.read_csv(\"training_data_output.csv\")\n", "\n", "display(pd.concat([x_train, y_train], axis=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that we have a data set with 4 features and 4 targets. There seem to be lots of missing entries in both the features and the targets." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "missing entries in input data: 28%\n", "missing entries in output data: 30%\n" ] } ], "source": [ "print(\"missing entries in input data:\", \"{}%\".format(int(np.round(x_train.isna().mean().mean()*100))))\n", "print(\"missing entries in output data:\", \"{}%\".format(int(np.round(y_train.isna().mean().mean()*100))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### We want to train a linear regression. What can we do about the missing entries?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Method 1: Mean imputation\n", "Fill all the missing entries with the mean value of their respective column" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " feature_0 feature_1 feature_2 feature_3 target_0 target_1 target_2 \\\n", "0 0.472986 -0.071868 0.242439 -1.700736 6.465163 1.247767 1.562335 \n", "1 0.753143 -1.534721 -0.183861 -0.120228 1.655926 1.461133 1.682604 \n", "2 -0.806982 2.871819 -0.183861 0.472457 -2.545614 1.605148 -3.310312 \n", "3 0.064623 -0.071868 1.342356 -0.122150 1.655926 1.049569 0.504198 \n", "4 1.012515 -0.913869 -1.029530 1.209796 1.655926 1.605148 1.132657 \n", ".. ... ... ... ... ... ... ... \n", "95 0.078516 -0.837245 1.094795 -0.093939 3.867749 1.255217 0.865133 \n", "96 0.959965 -1.167800 -0.334090 0.827424 0.544013 2.263673 0.136459 \n", "97 0.865017 -0.855405 0.071817 -1.125955 5.417294 1.349000 1.600092 \n", "98 -0.206309 0.421580 -0.183861 1.481052 -3.566368 1.444973 -0.434093 \n", "99 0.495926 -0.071868 -0.565377 -0.131805 1.555337 1.582580 0.622529 \n", "\n", " target_3 \n", "0 0.563148 \n", "1 -1.078739 \n", "2 0.749930 \n", "3 -0.312009 \n", "4 -0.651620 \n", ".. ... \n", "95 -0.312009 \n", "96 -0.057245 \n", "97 0.322496 \n", "98 -1.330253 \n", "99 0.949257 \n", "\n", "[100 rows x 8 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_train_imputed = x_train.fillna(value=x_train.mean())\n", "y_train_imputed = y_train.fillna(value=y_train.mean())\n", "\n", "display(pd.concat([x_train_imputed, y_train_imputed], axis=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and train a regressor with the imputed data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LinearRegression()" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LinearRegression\n", "imputed_linear_regression = LinearRegression()\n", "imputed_linear_regression.fit(x_train_imputed, y_train_imputed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Method 2: Deletion\n", "Remove all examples which are 
missing at least one feature or target entry" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " feature_0 feature_1 feature_2 feature_3 target_0 target_1 target_2 \\\n", "50 0.507523 -0.618371 0.790793 -0.834405 3.230641 0.885857 0.816932 \n", "54 -0.809670 0.500495 -0.193510 -0.664203 0.784997 1.141402 -1.039911 \n", "67 -1.556314 -0.693315 1.624609 -0.120666 -2.247848 0.721211 -0.982030 \n", "68 -2.348582 0.167257 1.699965 1.168899 -7.552532 0.995456 -3.366157 \n", "70 -0.488821 1.632122 -0.401225 1.009360 -3.450892 1.750701 -2.670395 \n", "92 -1.160888 -0.579329 0.279841 -0.409602 0.251455 1.572128 -0.305981 \n", "97 0.865017 -0.855405 0.071817 -1.125955 5.417294 1.349000 1.600092 \n", "\n", " target_3 \n", "50 -0.239825 \n", "54 -0.013783 \n", "67 -4.287106 \n", "68 -5.583932 \n", "70 -0.953182 \n", "92 -1.609754 \n", "97 0.322496 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "joint_dropped = pd.concat([x_train, y_train], axis=1).dropna(how=\"any\")\n", "x_train_dropped = joint_dropped[x_train.columns]\n", "y_train_dropped = joint_dropped[y_train.columns]\n", "\n", "display(pd.concat([x_train_dropped, y_train_dropped], axis=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and train a regressor with the remaining data" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LinearRegression()" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dropped_linear_regression = LinearRegression()\n", "dropped_linear_regression.fit(x_train_dropped, y_train_dropped)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Method 3: Bayesian model\n", "A Bayesian model can just treat the missing entries as unknowns " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import halerium.core as hal\n", "from halerium.core.regression import connect_via_regression\n", "\n", "g = hal.Graph(\"g\")\n", "with g:\n", " x = hal.Variable(\"x\", shape=(4,), mean=0, 
variance=1)\n", " y = hal.Variable(\"y\", shape=(4,), variance=0.1)\n", " connect_via_regression(\"reg\", inputs=[x], outputs=[y], order=1)\n", "\n", "# run this to show the graph in the online platform\n", "# hal.show(g)\n", "\n", "bayesian_train_model = hal.get_posterior_model(g, data={g.x: x_train, g.y: y_train}, method=\"MAP\")\n", "bayesian_post_graph = bayesian_train_model.get_posterior_graph()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Bayesian model actually estimates every missing entry (or rather, a probability distribution for it). Below, the original inputs are shown next to the model's mean estimates." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
    feature_0  feature_1  feature_2  feature_3
0    0.472986        NaN   0.242439  -1.700736
1    0.753143  -1.534721        NaN  -0.120228
2   -0.806982   2.871819        NaN   0.472457
3         NaN        NaN   1.342356  -0.122150
4    1.012515  -0.913869  -1.029530   1.209796
5    0.501872   0.138846   0.640761        NaN
6   -1.154360        NaN  -1.681757  -1.788094
7   -2.218535  -0.647431        NaN  -0.039209
8         NaN        NaN  -0.253904   0.073252
9   -0.997204  -0.713856        NaN  -0.677945
10  -0.571881  -0.105862        NaN   0.318665
11  -0.337595        NaN  -0.114920   2.241818
12        NaN   0.535136   0.232490   0.867612
13  -1.148213        NaN   1.000943        NaN
14        NaN        NaN   0.050523        NaN
15   0.943575   0.357644  -0.083449   0.677806
16        NaN   0.222719  -1.528985   1.029211
17  -1.166259  -1.009562  -0.105268   0.512022
18   1.407728        NaN   1.471234        NaN
19  -0.461395        NaN  -0.571817  -0.603299
20  -1.339389  -1.689653        NaN   0.257773
21   1.828821  -1.001002  -2.091691   0.146560
22  -0.466351        NaN        NaN  -1.259224
23        NaN   0.802630   0.272391  -0.969176
24   0.871968  -1.446359        NaN   0.197921
25  -1.365640        NaN   0.015935  -0.080043
26  -0.250803  -0.565143        NaN  -0.782282
27   3.041686  -0.626081        NaN  -0.587336
28        NaN   1.232045   0.450889  -0.641410
29        NaN   0.965746  -1.284003  -1.274572
30   1.522842   1.461882   0.037656  -0.246197
31        NaN        NaN        NaN  -1.513087
32        NaN   0.249203        NaN        NaN
33        NaN   1.689292   0.177750   0.032006
34   1.933216  -1.062095  -0.732629   0.842741
35   1.076740        NaN  -2.619493   0.739046
36   0.667501        NaN        NaN   1.407948
37   0.051149  -0.935975  -1.839109        NaN
38        NaN  -0.561885  -1.132469   0.274291
39   0.735912   0.434319  -1.120041   0.889095
40        NaN  -2.488004   0.595909  -2.035862
41        NaN   1.057642   0.652769        NaN
42  -0.883462   0.345692        NaN   0.410710
43        NaN   0.734148  -0.125496        NaN
44   0.202231        NaN  -1.421277  -1.163588
45        NaN   0.050022   0.765430  -0.028515
46  -1.205646        NaN   0.566844        NaN
47  -0.940359   0.283607  -0.390320  -2.154124
48        NaN  -0.566221  -0.517709        NaN
49  -0.603695        NaN  -0.959012  -1.595297
50   0.507523  -0.618371   0.790793  -0.834405
51   1.309470  -1.238742        NaN   0.696147
52   1.778984  -0.796317        NaN        NaN
53   0.789916        NaN  -2.184060  -1.567268
54  -0.809670   0.500495  -0.193510  -0.664203
55        NaN  -1.658425        NaN        NaN
56   1.269859   0.150519        NaN        NaN
57        NaN   0.164989        NaN  -0.115399
58        NaN        NaN   0.475514   2.639046
59   0.691108   1.111236  -0.257684  -1.195951
60        NaN  -1.163467  -3.015915        NaN
61   0.331393  -1.072815        NaN  -0.085521
62  -0.476624  -0.963715   1.153983  -0.444866
63        NaN  -0.474993  -0.791428  -1.693119
64  -0.741163        NaN        NaN        NaN
65        NaN        NaN  -0.818418  -0.177300
66   0.032502        NaN        NaN   0.210377
67  -1.556314  -0.693315   1.624609  -0.120666
68  -2.348582   0.167257   1.699965   1.168899
69   0.055338   0.217881        NaN  -0.158261
70  -0.488821   1.632122  -0.401225   1.009360
71  -1.577518  -0.788323  -1.156447   0.410545
72  -0.633212  -0.650858  -0.925059   0.143164
73   0.975512  -0.599755   0.607099  -0.018603
74  -0.621560   0.346610   1.337491        NaN
75   0.695248        NaN        NaN   0.763436
76   0.976937   0.517606   0.249171   1.304453
77   1.116544        NaN   0.662984  -0.904909
78  -0.158939        NaN  -0.043852  -0.666356
79        NaN        NaN        NaN  -1.300151
80  -0.511364  -0.692839        NaN   1.682377
81        NaN   0.200962   0.376479  -0.193338
82  -0.536373        NaN  -0.405771        NaN
83        NaN        NaN   0.331393        NaN
84   0.980989        NaN        NaN        NaN
85  -0.077496   0.410431   0.275277   0.525207
86        NaN   2.193451  -0.159283        NaN
87   0.168298   1.370530  -0.728801        NaN
88   1.229295   0.779550   0.215736        NaN
89   1.290819   0.455251  -0.571328  -0.465401
90  -0.632571   1.413624  -0.167273        NaN
91  -0.579659   1.121277   0.619558        NaN
92  -1.160888  -0.579329   0.279841  -0.409602
93        NaN   0.020903  -0.576144  -1.103720
94        NaN  -0.939964  -0.722252   0.251525
95   0.078516  -0.837245   1.094795        NaN
96   0.959965  -1.167800  -0.334090   0.827424
97   0.865017  -0.855405   0.071817  -1.125955
98  -0.206309   0.421580        NaN   1.481052
99   0.495926        NaN  -0.565377  -0.131805
\n", "
    feature_0  feature_1  feature_2  feature_3
0    0.472986  -0.759207   0.242439  -1.700736
1    0.753143  -1.534721   0.360400  -0.120228
2   -0.806982   2.871819  -0.901917   0.472457
3    0.533772  -1.179473   1.342356  -0.122150
4    1.012515  -0.913869  -1.029530   1.209796
5    0.501872   0.138846   0.640761   0.462867
6   -1.154360   0.000016  -1.681757  -1.788094
7   -2.218535  -0.647431  -0.350765  -0.039209
8    0.180386  -0.439187  -0.253904   0.073252
9   -0.997204  -0.713856   0.246540  -0.677945
10  -0.571881  -0.105862   1.447496   0.318665
11  -0.337595  -1.091010  -0.114920   2.241818
12  -2.736205   0.535136   0.232490   0.867612
13  -1.148213   1.240510   1.000943   0.583911
14  -0.133397  -0.819823   0.050523  -0.434769
15   0.943575   0.357644  -0.083449   0.677806
16   0.394197   0.222719  -1.528985   1.029211
17  -1.166259  -1.009562  -0.105268   0.512022
18   1.407728  -1.550861   1.471234   1.608291
19  -0.461395  -0.631948  -0.571817  -0.603299
20  -1.339389  -1.689653  -0.118997   0.257773
21   1.828821  -1.001002  -2.091691   0.146560
22  -0.466351   0.208242   0.408953  -1.259224
23  -0.410407   0.802630   0.272391  -0.969176
24   0.871968  -1.446359  -0.259302   0.197921
25  -1.365640  -1.612685   0.015935  -0.080043
26  -0.250803  -0.565143  -0.917229  -0.782282
27   3.041686  -0.626081   1.193158  -0.587336
28   1.073791   1.232045   0.450889  -0.641410
29  -1.003597   0.965746  -1.284003  -1.274572
30   1.522842   1.461882   0.037656  -0.246197
31  -0.706564   0.145486  -0.472127  -1.513087
32   0.418605   0.249203  -1.133117  -0.512646
33   0.436980   1.689292   0.177750   0.032006
34   1.933216  -1.062095  -0.732629   0.842741
35   1.076740   0.069814  -2.619493   0.739046
36   0.667501  -0.219459   0.635413   1.407948
37   0.051149  -0.935975  -1.839109  -0.060533
38  -0.575251  -0.561885  -1.132469   0.274291
39   0.735912   0.434319  -1.120041   0.889095
40   0.044935  -2.488004   0.595909  -2.035862
41  -0.001871   1.057642   0.652769   0.003109
42  -0.883462   0.345692  -1.679739   0.410710
43   0.109791   0.734148  -0.125496  -0.897293
44   0.202231  -0.035420  -1.421277  -1.163588
45  -1.291495   0.050022   0.765430  -0.028515
46  -1.205646  -0.207500   0.566844   0.835726
47  -0.940359   0.283607  -0.390320  -2.154124
48  -0.443188  -0.566221  -0.517709   0.358158
49  -0.603695   0.184274  -0.959012  -1.595297
50   0.507523  -0.618371   0.790793  -0.834405
51   1.309470  -1.238742  -1.157520   0.696147
52   1.778984  -0.796317   0.926392   1.833046
53   0.789916  -0.119241  -2.184060  -1.567268
54  -0.809670   0.500495  -0.193510  -0.664203
55   0.654011  -1.658425   0.240789  -0.615273
56   1.269859   0.150519  -1.137418  -0.680453
57  -1.376112   0.164989  -1.365167  -0.115399
58   0.894641   0.094943   0.475514   2.639046
59   0.691108   1.111236  -0.257684  -1.195951
60  -0.217195  -1.163467  -3.015915   0.357342
61   0.331393  -1.072815   1.607594  -0.085521
62  -0.476624  -0.963715   1.153983  -0.444866
63  -0.220168  -0.474993  -0.791428  -1.693119
64  -0.741163  -0.666493   0.601783  -1.320586
65   0.498025  -0.692433  -0.818418  -0.177300
66   0.032502  -0.380547   0.508512   0.210377
67  -1.556314  -0.693315   1.624609  -0.120666
68  -2.348582   0.167257   1.699965   1.168899
69   0.055338   0.217881   0.491295  -0.158261
70  -0.488821   1.632122  -0.401225   1.009360
71  -1.577518  -0.788323  -1.156447   0.410545
72  -0.633212  -0.650858  -0.925059   0.143164
73   0.975512  -0.599755   0.607099  -0.018603
74  -0.621560   0.346610   1.337491  -2.588218
75   0.695248   0.598883   0.739468   0.763436
76   0.976937   0.517606   0.249171   1.304453
77   1.116544   0.133607   0.662984  -0.904909
78  -0.158939   0.230231  -0.043852  -0.666356
79   1.125959   0.713062   0.539098  -1.300151
80  -0.511364  -0.692839  -0.747252   1.682377
81   2.397169   0.200962   0.376479  -0.193338
82  -0.536373   1.193916  -0.405771  -1.085892
83   0.728069  -0.042654   0.331393   0.364764
84   0.980989   0.735467   0.519223   0.578556
85  -0.077496   0.410431   0.275277   0.525207
86  -0.014028   2.193451  -0.159283   0.273037
87   0.168298   1.370530  -0.728801  -1.226624
88   1.229295   0.779550   0.215736  -0.646722
89   1.290819   0.455251  -0.571328  -0.465401
90  -0.632571   1.413624  -0.167273  -1.041896
91  -0.579659   1.121277   0.619558  -0.399106
92  -1.160888  -0.579329   0.279841  -0.409602
93  -0.364879   0.020903  -0.576144  -1.103720
94  -0.851016  -0.939964  -0.722252   0.251525
95   0.078516  -0.837245   1.094795  -1.177459
96   0.959965  -1.167800  -0.334090   0.827424
97   0.865017  -0.855405   0.071817  -1.125955
98  -0.206309   0.421580  -0.704793   1.481052
99   0.495926   0.324680  -0.565377  -0.131805
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_train_bayesian_imputed = bayesian_train_model.get_means(g.x)\n", "x_train_bayesian_imputed = pd.DataFrame(data=x_train_bayesian_imputed, columns=x_train.columns)\n", "\n", "from plots import display_side_by_side\n", "display_side_by_side(x_train, x_train_bayesian_imputed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare the performance on test data" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "x_test = pd.read_csv(\"testing_data_input.csv\").values\n", "y_test = pd.read_csv(\"testing_data_output.csv\").values\n", "\n", "imputed_prediction = imputed_linear_regression.predict(x_test)\n", "\n", "dropped_prediction = dropped_linear_regression.predict(x_test)\n", "\n", "bayesian_prediction_model = hal.get_generative_model(bayesian_post_graph, data={g.x: x_test})\n", "bayesian_prediction = bayesian_prediction_model.get_means(g.y)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Bayesian model')" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import pylab as pl\n", "\n", "pl.figure(figsize=(12, 12))\n", "\n", "dark_erium_green = '#00b34a'\n", "erium_blue = '#002a43'\n", "\n", "ax = pl.subplot(2,2,1)\n", "ax.set_aspect(\"equal\")\n", "ax.scatter(y_test[:,0], imputed_prediction[:,0], color='#00b34a')\n", "ax.plot([np.min(y_test[:,0]),np.max(y_test[:,0])], [np.min(y_test[:,0]),np.max(y_test[:,0])], ls=\":\", color=\"k\")\n", "ax.set_xlabel(\"real output value\")\n", "ax.set_ylabel(\"predicted output value\")\n", "ax.set_title(\"training with imputed data\")\n", "\n", "ax = pl.subplot(2,2,2)\n", "ax.set_aspect(\"equal\")\n", "ax.scatter(y_test[:,0], dropped_prediction[:,0], color='#00b34a')\n", "ax.plot([np.min(y_test[:,0]),np.max(y_test[:,0])], [np.min(y_test[:,0]),np.max(y_test[:,0])], ls=\":\", color=\"k\")\n", "ax.set_xlabel(\"real output value\")\n", "ax.set_ylabel(\"predicted output value\")\n", "ax.set_title(\"training with missing rows dropped\")\n", "\n", "ax = pl.subplot(2,1,2)\n", "ax.set_aspect(\"equal\")\n", "ax.scatter(y_test[:,0], bayesian_prediction[:,0], color='#00b34a')\n", "ax.plot([np.min(y_test[:,0]),np.max(y_test[:,0])], [np.min(y_test[:,0]),np.max(y_test[:,0])], ls=\":\", color=\"k\")\n", "ax.set_xlabel(\"real output value\")\n", "ax.set_ylabel(\"predicted output value\")\n", "ax.set_title(\"Bayesian model\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, Bayesian models offer interesting advantages when dealing with missing data. Missing data often occur in industrial environments, e.g. when a sensor output could not be recorded or the output was corrupted." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }