{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# More on training models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In Halerium, a model with data for some of its variables usually needs to be \"solved\" before it can be used to compute statistical properties consistent with the provided data. This solving step is also referred to as \"fitting the model to the data\" (which is roughly what happens under the hood) or, in particular in the context of machine learning, as \"training\".\n", "\n", "Models can be created and trained with the model factory function `get_posterior_model`; see [Training with model factories](#training_with_model_factories).\n", "\n", "Models can also be created and trained directly with the model classes in `halerium.core.model` and their methods; see [Training models directly](#training_models).\n", "\n", "Moreover, models can be trained with the `Trainer` class; see [Trainer](./02_trainer.ipynb).\n", "\n", "The result of the training is the posterior graph, which represents the trained model. This graph can be used to compute predictions; see [Outlook](#predictions) and [Predictor](./../02_objectives/03_predictor.ipynb)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to import the following packages, classes, and functions." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# for handling data:\n", "import numpy as np\n", "\n", "# for plotting:\n", "import matplotlib.pyplot as plt\n", "\n", "# for graphs:\n", "from halerium.core import Graph, Variable, StaticVariable, show\n", "\n", "# for creating models with a factory:\n", "from halerium.core.model import get_posterior_model\n", "\n", "# for creating and using models directly:\n", "from halerium.core.model import MAPModel, MAPFisherModel, ForwardModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The graph and data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train a model, we need a graph that represents the prior statistical properties of the variables and the connections between them, as well as data for some of those variables.\n", "\n", "Let us define a simple graph." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "graph = Graph(\"graph\")\n", "with graph:\n", " x = Variable(\"x\", shape=(), mean=0, variance=1**2)\n", " a = StaticVariable(\"a\", shape=(), mean=1, variance=5**2)\n", " b = StaticVariable(\"b\", shape=(), mean=0, variance=5**2)\n", " y = Variable(\"y\", shape=(), mean=a * x + b, variance=1**2)\n", "\n", "show(graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we create some data for training." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "true_slope = 2\n", "true_intercept = 1.5\n", "\n", "x_train_data = np.linspace(-10, 10, 40)\n", "y_train_data = true_slope * x_train_data + true_intercept + np.random.normal(size=x_train_data.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot the training data to get a visual impression." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEGCAYAAACO8lkDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAASxElEQVR4nO3de4xc513G8efBzgVIRRK8SdzYwS6YqE4BNVlMgIIqkhIniuI0EOT8QQ0psoJq0UggmmDUblUCvYhKXFrANFFNlTZEakOsqlHjhKByy2Ud5WLXce00MVls7G0DTVFFwMmPP87ZejxnZs/MzpzLnPP9SNbOzJmd/e073vPMed/zvscRIQAAOn1P1QUAAOqHcAAAZBAOAIAMwgEAkEE4AAAyllddwDisWLEi1qxZU3UZADBR9uzZ842ImOq1rRHhsGbNGs3OzlZdBgBMFNuH+22jWwkAkEE4AAAyCAcAQAbhAADIIBwAABmEAwBMsJmZYl6XcACACfbBDxbzuoQDACCDcACACTMzI9nJP+nk7XF2MbkJF/uZnp4OZkgDaCNbWupu3PaeiJjutY0jBwCosaIGnPMQDgBQY3kDzh/4QDE/l3AAgAnGqawA0BJlDDjnYUAaAGpslAHn/NdmQBoAMITKwsH2atuP2N5ve5/t96aPn2t7t+2D6ddzqqoRAKpW1IBzniqPHE5I+u2IeLOkyyW9x/Z6SbdJejgi1kl6OL0PAK3UulNZI+JoRDyZ3v62pP2SLpS0SdLO9Gk7JV1fSYEA0GK1GHOwvUbSWyU9Jun8iDgqJQEi6bw+37PV9qzt2fn5+dJqBYBhVPXJf1SVh4PtsyR9XtKtEfHKoN8XETsiYjoipqempoorEABGUNSqqUWrNBxsn6YkGO6OiC+kDx+zvTLdvlLS8arqA4C2qvJsJUu6U9L+iPh4x6Zdkrakt7dIur/s2gBgFHWYxDaqyibB2X6bpH+U9Kyk19OHf0/JuMO9ki6S9G+SboyIlxd7LSbBAairIiexjWqxSXDLyy5mQUT8kyT32XxFmbUAQFVmZup5RFH5gDQANFneJLa6DlgTDgBQoDoeFQyCcACAkk3CgDXhAAAjGnanPjOTDFIvDFQv3CYcAGCC5O206zpuMArCAQByFLnzr2rV1TyEAwAswbjGDerUldSJcACAHvJ2/pMwbjAKLhMKADnyZjnXeRb0YrhMKAAUqK7jBqMgHAAgR97OvyldSZ0IBwDI0cSdfx7CAQCQQTgAADIIBwBABuEAoPXaOKaQh3AA0HpNXBtpVIQDgMbjyGB4hAOAxut1ZDAJ11SoEstnAGi8pi5/MSqWzwDQOhwZjGZ51QUAQBFmZk4GQd6RQRPXRhoVRw4AJt6oRwMcTWQRDgAmXt6pqBwZDI9wANB4HBkMj3AAMJEYcC4Wp7ICmHhtPRV1VJzKCgAYCuEAYOIx4Dx+hAOAUhQ5FsA4w/gRDgBKwcqnk4VwAABkEA4ACsPpppOLU1kBlILTTeuHU1kBAEMhHACUgtNNJwvhAKAUjDNMlkrDwfZdto/b3tvx2Lm2d9s+mH49p8oaAZSD8KiXqo8cPi1pY9djt0l6OCLWSXo4vQ+g4fLmQRAe5ao0HCLiK5Je7np4k6Sd6e2dkq4vsyYA9cQkunJVfeTQy/kRcVSS0q/nVVwPgIIwD6K+6hgOA7G91fas7dn5+fmqywGwBDMzydyHhfkPC7cXwoHwqE4dw+GY7ZWSlH493utJEbEjIqYjYnpqaqrUAgFkFbHDzgsPFKeO4bBL0pb09hZJ91dYC4ABjTomwDyIeqn6VNbPSfpXSRfbnrP9bkkflvQO2wclvSO9D6Dh8o4GCI9yVX220k0RsTIiTouIVRFxZ0R8MyKuiIh16dfus5kA1ES
ZYwJ0JZWLhfcAjAUL600eFt4DAAyFcAAwFowJNAvhAGAsGBNoFsIBAJBBOAAAMggHAAOh26hdCAcA37VYALAqarsQDgC+iwDAAsIBQF+sitpehAPQcosFAKuithfhADTIUnbaBAB6IRyABinyOszMgG4XwgFokbzwWCwAOJJoF8IBmCC9dtDjHDQmALCAcAAmSK9P/lyHGUXgeg7ABMm7ZsKo29EuXM8BmGDDfPJn0BjjQjgAJRu2O2eYU025DjPGhW4loGSjdO3QLYRxolsJaAg++aMshANQgnGdMcQZRigL3UpAyegaQl3QrQQAGArhAJSMcQNMAsIBKBnjBpgEhAMAIINwAABkEA5AjdDlhLogHIAaybveAlAWwgEAkEE4ABXjeguoI2ZIAzXC7GmUiRnSAIChEA7AkPK6e0bpDmL2NOoit1vJ9jZJd0fEf5ZT0vDoVkKZuBQnmmLUbqULJD1h+17bG+2FYTMAQFPlhkNE/L6kdZLulPRrkg7a/kPbP1xwbUBt5J1RxBlHaJqBz1ay/ROSfl3SRkmPSLpc0u6I+N1CCrM3SvoTScskfSoiPtzvuXQroUx0K6EpRupWsv1btvdI+qikf5b0YxHxm5Iuk/RLY6305M9cJukTkq6WtF7STbbXF/GzAABZywd4zgpJN0TE4c4HI+J129cWU5Y2SDoUEV+XJNv3SNok6asF/TxgYHlnFHHGEZqglpPgbP+ypI0R8Rvp/V+V9FMRsa3jOVslbZWkiy666LLDhw/3fC0AQG+TOAmu1xlRp6RYROyIiOmImJ6amiqpLABoh7qGw5yk1R33V0k6UlEtANA6dQ2HJySts73W9umSNkvaVXFNaAlOPwVqGg4RcULSNklflrRf0r0Rsa/aqtAUeTt/rqkA1HRAeljMc8AwmKcAJCZxQBooFTOcgVMRDmiFQZa/iDh5xLBwm3BAW9GthNahWwlI0K0EDIEZzgDhgIZarDsob+dPVxJAtxIaiq4hIB/dSgCAoRAOaAxORwXGh24lNBLdSkA+upUAAEMhHNBInI4KjIZwQCMxzgCMhnAAAGQQDgCADMIBAJBBOKCWGDMAqkU4oJa4GhtQLcIBAJBBOKA2WP4CqA+Wz0AtsfwFUDyWz0DjcDQBFItwQC3lLX/BgDVQLMIBlcj75M+RAVAtwgGVWMonfwasgfIwII1KjDrgzIA1MDoGpFELfPIHJgfhgEL02uHPzCSf9hc+8S/cXko4cL0GoFh0K6EQed0+dAsB1aNbCbXDJ3+g3ggHjM0wYwqMMwD1Rjhgybp38OMcUwBQLcIBS8YsZaC5CAcUgjEFYLIRDhjKoOMKdCUBk41TWbFknI4KTDZOZQUADIVwwJIxrgA0F+GAJWNcAWiuSsLB9o2299l+3fZ017bbbR+yfcD2VVXUBwBtt7yin7tX0g2S/qrzQdvrJW2WdImkN0p6yPaPRsRr5ZcIAO1VyZFDROyPiAM9Nm2SdE9EvBoRL0g6JGlDudUBAOo25nChpJc67s+lj2XY3mp71vbs/Px8KcXVTdF9/owpAO1VWDjYfsj23h7/Ni32bT0e63kmfUTsiIjpiJiempoaT9ETpujlK1geA2ivwsYcIuLKJXzbnKTVHfdXSToynooAAIOqW7fSLkmbbZ9he62kdZIer7imWhnnpTb7LaXNpTwBVLJ8hu13SvozSVOS/kvSUxFxVbptu6SbJZ2QdGtEPJD3em1dPmPU5Su4WhvQbostn1HJqawRcZ+k+/psu0PSHeVWBADoVLduJQxhKctXDNNtxPIYQHuxKmuDzcwsPlZAtxHQbqzK2lKcigpgqQiHFqPbCEA/hEPDDDOmwOmpAPphzKHBGFMAsBjGHAAAQyEcGowxBQBLRTg0GGMKAJaKcKg5dvAAqkA41BxzFQBUgXAAAGQQDiMqotuHZbMBVI15DiMqei4BcxUAFIV5DhXi0z6ASUQ4LMEw3T6jDigzVwFAFehWGhFXUwMwqehWKhm
L3wGYdBw5jGjUC+pwZAGgKhw5FIhP/gCaiHAoWK8BZeYxAKg7upUqRrcSgKrQrQQAGArhUDHmMQCoI8KhYowzAKgjwgEAkEE4AAAyCAcAQAbhkIMxAQBtRDjk4DKdANqIcAAAZBAOPbC8BYC2Y/mMHCxvAaCpWD4DADAUwiEHy1sAaCPCIQfjDADaiHAAAGQQDgCAjErCwfbHbD9n+xnb99k+u2Pb7bYP2T5g+6oy6qHrCABOVdWRw25Jb4mIH5f0NUm3S5Lt9ZI2S7pE0kZJn7S9rOhimAUNAKeqJBwi4sGIOJHefVTSqvT2Jkn3RMSrEfGCpEOSNlRRIwC0WR3GHG6W9EB6+0JJL3Vsm0sfy7C91fas7dn5+fmhfyizoAGgv+VFvbDthyRd0GPT9oi4P33OdkknJN298G09nt9zfnJE7JC0Q0pmSA9b38zMySBgFjQAnKqwcIiIKxfbbnuLpGslXREn1/CYk7S642mrJB0ppkIAQD9Vna20UdL7JF0XEd/p2LRL0mbbZ9heK2mdpMeLrodZ0ABwqsKOHHL8uaQzJO120un/aETcEhH7bN8r6atKupveExGvFV0M4wwAcKpKwiEifmSRbXdIuqPEcgAAXepwthIAoGYIBwBABuEAAMggHAAAGY24TKjteUmHR3iJFZK+MaZyxom6hkNdw6Gu4TSxrh+KiKleGxoRDqOyPdvvOqpVoq7hUNdwqGs4bauLbiUAQAbhAADIIBwSO6ouoA/qGg51DYe6htOquhhzAABkcOQAAMggHAAAGa0IB9s32t5n+3Xb013bbrd9yPYB21f1+f5zbe+2fTD9ek5Bdf6t7afSfy/afqrP8160/Wz6vNkiaun6eTO2/72jtmv6PG9j2o6HbN9WQl0fs/2c7Wds32f77D7PK7y98n53J/403f6M7UuLqKPHz11t+xHb+9O/gff2eM7bbX+r4/19f0m1Lfq+VNFmti/uaIenbL9i+9au55TSXrbvsn3c9t6OxwbaF43lbzEiGv9P0pslXSzpHyRNdzy+XtLTSpYPXyvpeUnLenz/RyXdlt6+TdJHSqj5jyW9v8+2FyWtKLH9ZiT9Ts5zlqXt9yZJp6ftur7gun5R0vL09kf6vS9Ft9cgv7uka5RcDteSLpf0WEnv3UpJl6a33yDpaz1qe7ukL5b1/2nQ96WqNut6X/9DyUSx0ttL0s9LulTS3o7HcvdF4/pbbMWRQ0Tsj4gDPTZtknRPRLwaES9IOiRpQ5/n7Uxv75R0fSGFppxc5OJXJH2uyJ8zZhskHYqIr0fE/0q6R0m7FSYiHoyIE+ndR5VcObAKg/zumyT9TSQelXS27ZVFFxYRRyPiyfT2tyXtV5/rstdQJW3W4QpJz0fEKKsvLFlEfEXSy10PD7IvGsvfYivCYREXSnqp4/6cev/hnB8RR6Xkj03SeQXX9XOSjkXEwT7bQ9KDtvfY3lpwLQu2pYf2d/U5lB20LYtys5JPmb0U3V6D/O5Vt49sr5H0VkmP9dj807aftv2A7UtKKinvfam6zTar/we0KtpLGmxfNJZ2q+pKcGNn+yFJF/TYtD0i7u/3bT0eK/Tc3gHrvEmLHzX8bEQcsX2ekqvpPZd+yiikLkl/IelDStrmQ0q6vG7ufoke3ztyWw7SXra3K7ly4N19Xmbs7dVdZo/Hun/30v+vnfLD7bMkfV7SrRHxStfmJ5V0nfx3Op70d0ou0Vu0vPelsjazfbqk6yTd3mNzVe01qLG0W2PCISKuXMK3zUla3XF/laQjPZ53zPbKiDiaHtYeX0qNUn6dtpdLukHSZYu8xpH063Hb9yk5jBxpZzdo+9n+a0lf7LFp0LYca122t0i6VtIVkXa49niNsbdXl0F+90LaZxC2T1MSDHdHxBe6t3eGRUR8yfYnba+IiEIXmRvgfamszSRdLenJiDjWvaGq9koNsi8aS7u1vVtpl6TNts+wvVZJ+j/e53lb0ttbJPU7EhmHKyU9FxFzvTb
a/n7bb1i4rWRQdm+v545LVz/vO/v8vCckrbO9Nv3UtVlJuxVZ10ZJ75N0XUR8p89zymivQX73XZLelZ6Bc7mkby10DxQpHb+6U9L+iPh4n+dckD5Ptjco2S98s+C6BnlfKmmzVN+j9yraq8Mg+6Lx/C0WPeJeh39Kdmhzkl6VdEzSlzu2bVcysn9A0tUdj39K6ZlNkn5Q0sOSDqZfzy2w1k9LuqXrsTdK+lJ6+01Kzj54WtI+Jd0rRbffZyQ9K+mZ9D/Zyu660vvXKDkb5vmS6jqkpG/1qfTfX1bVXr1+d0m3LLyXSg71P5Fuf1YdZ80V3EZvU9Kl8ExHO13TVdu2tG2eVjKw/zMl1NXzfalJm32fkp39D3Q8Vnp7KQmno5L+L91/vbvfvqiIv0WWzwAAZLS9WwkA0APhAADIIBwAABmEAwAgg3AAAGQQDgCADMIBAJBBOAAFsP2T6UKFZ6azgffZfkvVdQGDYhIcUBDbfyDpTEnfK2kuIv6o4pKAgREOQEHSdW2ekPQ/SpZYeK3ikoCB0a0EFOdcSWcpuQLbmRXXAgyFIwegILZ3KbkK11olixVuq7gkYGCNuZ4DUCe23yXpRER81vYySf9i+xci4u+rrg0YBEcOAIAMxhwAABmEAwAgg3AAAGQQDgCADMIBAJBBOAAAMggHAEDG/wPnzj4HnCS6RAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(x_train_data, y_train_data, '+b');\n", "plt.xlabel('x');\n", "plt.ylabel('y');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training data can be passed to the model factories and to the model classes as a dictionary that maps the variables in the graph to their data." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "train_data = {graph.x: x_train_data, graph.y: y_train_data}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training with model factories\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create and train a model in one step, one can use the `get_posterior_model` function.\n", "The training data can be provided directly as a dictionary." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "model = get_posterior_model(graph=graph, data=train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unless specified otherwise, the returned model is already solved, so the posterior graph can be obtained directly." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "trained_graph = model.get_posterior_graph()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training models directly\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Models can also be created and trained directly with the model classes in `halerium.core.model` and their methods.\n", "\n", "First, we pick a model class suitable for training, e.g. `MAPModel`. Then we create an instance of that class, providing the graph and the data. The data can be packaged into a data linker or, as here, passed as a dictionary." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model = MAPModel(graph=graph,\n", " data=train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can train the model using its `solve` method." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "model.solve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a result, the model has adjusted the distributions of the model parameters `graph.a` and `graph.b` to the training data. Their trained means are close to the true slope (2) and intercept (1.5) used to generate the data." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trained mean for a: 2.0128642238484\n", "trained mean for b: 1.4387307591662526\n" ] } ], "source": [ "a_trained_mean = model.get_means(graph.a)\n", "b_trained_mean = model.get_means(graph.b)\n", "\n", "print(\"trained mean for a:\", a_trained_mean)\n", "print(\"trained mean for b:\", b_trained_mean)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now obtain the trained graph, i.e., the posterior graph." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "trained_graph = model.get_posterior_graph()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Outlook: Using a trained model for predictions\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The trained graph can be used to compute predictions, e.g., for `graph.y` given new data for `graph.x`.\n", "To do so, we create new data for `graph.x`, build a new model from the trained graph with the new data as input, solve that model, and extract the means of `graph.y` from it." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "predicted values for y: [-18.68991148 -16.67704726 -14.66418303 -12.65131881 -10.63845458\n", " -8.62559036 -6.61272614 -4.59986191 -2.58699769 -0.57413346\n", " 1.43873076 3.45159498 5.46445921 7.47732343 9.49018765\n", " 11.50305188 13.5159161 15.52878033 17.54164455 19.55450877\n", " 21.567373 ]\n" ] } ], "source": [ "x_prediction_data = np.linspace(-10, 10, 21)\n", "\n", "prediction_input_data = {graph.x: x_prediction_data}\n", "\n", "trained_model = MAPModel(graph=trained_graph,\n", " data=prediction_input_data)\n", "\n", "trained_model.solve()\n", "y_prediction_data = trained_model.get_means(graph.y)\n", "print(\"predicted values for y:\", y_prediction_data)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(x_train_data, y_train_data, '+b');\n", "plt.plot(x_prediction_data, y_prediction_data, '.r');\n", "plt.xlabel('x');\n", "plt.ylabel('y');\n", "plt.legend(['training data', 'predictions']);" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }