{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Simple training with the Trainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In Halerium, a model with data for some of its variables usually needs to be \"solved\" before it can be used to compute statistical properties consistent with the provided data. This solving step is also referred to as \"fitting the model to the data\" (which is roughly what happens under the hood) or, in particular in the context of machine learning, as \"training\".\n", "\n", "The simplest way to train a model is to use the `Trainer` class.\n", "\n", "The result of the training is the posterior graph representing the trained model. This graph can be used to compute predictions from the trained model; see the [Outlook](#predictions) below and the [Predictor](./../02_objectives/03_predictor.ipynb) notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to import the following packages, classes, and functions." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# for handling data\n", "import numpy as np\n", "\n", "# for plotting\n", "import matplotlib.pyplot as plt\n", "\n", "# for building graphs\n", "from halerium.core import Graph, Variable, StaticVariable, show\n", "\n", "# for training models\n", "from halerium.core.model import Trainer\n", "\n", "# for computing predictions\n", "from halerium import Predictor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The graph and data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train a model, we need a graph that represents the prior statistical properties of the variables and the connections between them, as well as data for some of those variables.\n", "\n", "Let us define a simple graph." 
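, "\n", "\n", "Written out, and assuming the default (normal) distribution for each variable, the graph defined in the next cell encodes the generative model\n", "\n", "$$\n", "x \\sim \\mathcal{N}(0,\\, 1^2), \\qquad a \\sim \\mathcal{N}(1,\\, 5^2), \\qquad b \\sim \\mathcal{N}(0,\\, 5^2), \\qquad y \\mid x, a, b \\sim \\mathcal{N}(a x + b,\\, 1^2),\n", "$$\n", "\n", "where the second argument of $\\mathcal{N}$ is the variance. The slope `a` and the intercept `b` therefore have broad priors (standard deviation 5), which express only weak prior beliefs about their values."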
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "graph = Graph(\"graph\")\n", "with graph:\n", "    # input with a standard normal prior (one value per data point)\n", "    x = Variable(\"x\", shape=(), mean=0, variance=1**2)\n", "    # slope and intercept: static variables, i.e. one value shared by all data points\n", "    a = StaticVariable(\"a\", shape=(), mean=1, variance=5**2)\n", "    b = StaticVariable(\"b\", shape=(), mean=0, variance=5**2)\n", "    # output: linear in x plus noise with unit variance\n", "    y = Variable(\"y\", shape=(), mean=a * x + b, variance=1**2)\n", "\n", "show(graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we create some synthetic data for training: points on a straight line with added Gaussian noise." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# parameters of the line used to generate the data\n", "true_slope = 2\n", "true_intercept = 1.5\n", "\n", "# 40 evenly spaced inputs and outputs with standard normal noise\n", "x_train_data = np.linspace(-10, 10, 40)\n", "y_train_data = true_slope * x_train_data + true_intercept + np.random.normal(size=x_train_data.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot the training data to get a visual impression." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYcAAAEGCAYAAACO8lkDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR4ElEQVR4nO3da4xc9X3G8ecp5tKWqEC9gIOhdiIX4aYvAltEmraKCmkARTFJRQUvGldBslIFNUiNVFMqmCi9JFTNi0rpxS0oTkWhqAnFolACNFXUSwhrxM01rk2CyxbHdpI2pKpK6vDri3M2GebM7JnZmTP/c/l+JGtnzpnd/e0Z73n2fzvHESEAAPr9QOoCAAD1QzgAAAoIBwBAAeEAACggHAAABetSFzAL69evj02bNqUuAwAaZe/evV+PiIVh+1oRDps2bdLS0lLqMgCgUWwfHrWPbiUAQAHhAAAoIBwAAAWEAwCggHAAABQQDgDQYL1eNV+XcACABvvoR6v5uoQDAKCAcACAhun1JDv7J33/8Sy7mNyGm/0sLi4GK6QBdJEtrfU0bntvRCwO20fLAQBQQDgAQIPddls1X5dwAIAGYyorAHRQVSf/MoQDANRYVesYyhAOAIACwgEAamYe6xjKsM4BAGpsmnUM5V+bdQ4AgAkQDgBQY1WtYyhDOABAjTGVFQBQG8nCwfb5tr9ge7/tfbY/nG8/y/Yjtg/mH89MVSMAdFXKlsMJSb8eERdJukzSh2xvlbRT0mMRsUXSY/lzAGilVN1GZZKFQ0QciYgn88fflrRf0nmStknanb9st6RrkhQIAHOQagV0mVqMOdjeJOmtkh6XdE5EHJGyAJF09ojP2WF7yfbS8ePH51YrAHRB8nCwfbqkz0q6KSJeGffzImJXRCxGxOLCwkJ1BQLAjNVhBXSZpOFg+2RlwXBXRHwu33zU9oZ8/wZJx1LVBwBV6PWyVc8rK59XHhMOkmxb0h2S9kfEJ/t27ZG0PX+8XdL9864NALpuXcLv/XZJvyzpWdtP5dt+U9LHJd1r+wZJ/y7p2jTlAUD1Uq2ALpMsHCLiHyV5xO7L51kLAFSl11u9u6hOXUn9kg9IA0Cb1XWqahnCAQBQQDgAwIw1YapqGW72AwAVqvJmPdPiZj8AgIkQDgAwpdW6i+o6VbUM3UoAMKU6dx2thm4lAMBECAcAWIM2zEhaDd1KADAlupUAoIPa0hqYBOEAACXKLoHR1BlJqyEcAGBKbWxZEA4AMETbB5zLMCANACWaOuBchgFpAMBECAcAjVd1V08bB5zL0K0EoPHa2u1TNbqVAAATIRwANNIsZxN1ZQbSJOhWAtB403YrdbVbim4lAMBECAcAjbeW2URdX+RWhm4lAJ1Ht1IRLQcAQAHhAKDzurjIrQzhAKDzGGcoIhwAtB4n/8kRDgBar+xmPSgiHAAABYQDgFZiHcN0WOcAoPW6uo6hDOscAAATIRwAtB7rGCZHOABohGnGChhnmBzhAKARmI46X4QDAKCAcABQW0xHTYdwAFALw074vV42BXVlGurKY8KheknDwfadto/Zfq5v21m2H7F9MP94ZsoaAcwHYwr1krrl8GlJVw5s2ynpsYjYIumx/DmAhpv2r32mo85X0nCIiC9K+ubA5m2SduePd0u6Zp41AajGsJbBJGMKdCXNV/LLZ9jeJOmBiHhL/vy/IuKMvv3/GRGFriXbOyTtkKQLLrjgksOHD8+nYABrUnYJCy5xMX+tvHxGROyKiMWIWFxYWEhdDoAhmG3UXHUMh6O2N0hS/vFY4noAjGHa2UaMKdRLHcNhj6Tt+ePtku5PWAuAMU0724jWRL2knsp6t6R/kXSh7WXbN0j6uKR32j4o6Z35cwANR8ugWVLPVro+IjZExMkRsTEi7oiIb0TE5RGxJf84OJsJQE0w26i9ks9WmgVu9gOkx2yj5mnlbCUAQHUIBwAzwZhCuxAOAGaCMYV2IRwAAAWEA4DvWe2vf1oG3UI
4AB1SdoJfbSEbl9TuFsIB6BBO8BgX4QB03GoL2bhwXnexCA5ouV5veIvhttuKJ/nVFrKxyK19WAQHdAT3YcasEA5Ai0w7prDaQjYWuXUL4QB0SNkJnqmsWEE4AA3HlVFRBcIBaBDGFDAvhAPQIKxTwLwQDkCLMGiMWSEcgJpjTAEpsAgOaBAWomGWWAQHAJgI4QA0CGMKmBfCAWgQxhQwL4QDAKCAcAAAFBAOAIACwgEAUEA4AAAKSsPB9o22z5xHMQCAehin5XCupCds32v7SntlET8AoK1KwyEifkvSFkl3SPoVSQdt/67tN1dcGwAgkbHGHCK7ANPX8n8nJJ0p6a9t315hbUDnsMgNdTHOmMOv2d4r6XZJ/yTpJyPiVyVdIukXK64P6BTu14C6WDfGa9ZLel9EHO7fGBGv2X53NWUBAFIaZ8zh1sFg6Nu3f/YlAc02adfQJPdrAOaF+zkAM1Z2z4Veb/SJn/s1YJ64nwNQI4wroAkIB2AGZtU1xP0aUBeEAzADvV7WHbTSJbTyeCUcxg0PxhlQF4w5ADNWNm7AuALqopFjDvmlOg7YPmR7Z+p6gHHRNYQ2qGU42D5J0qckXSVpq6TrbW9NWxWQKev6KdtPeKAJahkOki6VdCgivhIR35F0j6RtiWsCJE0/24hxBTRBXcPhPEkv9T1fzrd9j+0dtpdsLx0/fnyuxQFA29U1HIZdFvx1Q3gRsSsiFiNicWFhYU5loatYxYyuGefaSiksSzq/7/lGSS8nqgV43apmZhuhC+racnhC0hbbm22fIuk6SXsS14SW4K99oFwtwyEiTki6UdLDkvZLujci9qWtCm0x7YAys43QBXXtVlJEPCjpwdR1AINoeaALatlyAGaNAWVgMlw+A53DgDKQaeTlMwAA6RAO6JyyAWW6mgC6lYACup3QFXQrAQAmQjiglSbtGmI2E/B6dCuhlabpGqJbCV1BtxIAYCKEA1pjVl1DXB4DoFsJLUXXEFCObiUAwEQIB7QSXUPAdAgHtBJTUIHpEA5oJE7+QLUIBzTStDfsAbA6wgEAUEA4oDG4xAUwP6xzQCOxjgGYHuscAAATIRxQS2VdRaxjAKpFtxJqiW4joHp0K6F2GEQG6o1wQBLD1ikwGwmoD7qVkERZtxHdSkD16FZCLdAyAJpjXeoC0B293veDoKxlwGwkIC1aDqglWhNAWoQDkqBlANQb4YAkaBkA9UY4AAAKCAdUgpYB0GyEAyrBzXiAZiMcAAAFhANmhkVuQHtw+QxUgstfAPXH5TNQCVoEQHsRDliz1QadWeQGNBvhgErQqgCaLUk42L7W9j7br9leHNh3s+1Dtg/YfleK+pAZdoJn0BnohiQD0rYvkvSapD+V9JGIWMq3b5V0t6RLJb1R0qOSfjwivrva12NAuhrccwFot9oNSEfE/og4MGTXNkn3RMSrEfFVSYeUBQUAYI7qNuZwnqSX+p4v59sKbO+wvWR76fjx43Mprgsm6TZi0Blor8pu9mP7UUnnDtl1S0TcP+rThmwb2nEREbsk7ZKybqU1FYmCSW7IwzgD0F6VhUNEXLGGT1uWdH7f842SXp5NRRjUHwQA0K9u3Up7JF1n+1TbmyVtkfTlxDXV1rQn9rKL49FtBHRXqqms77W9LOltkv7W9sOSFBH7JN0r6V8l/Z2kD5XNVOqyqq98SqsC6K5Us5Xui4iNEXFqRJwTEe/q2/c7EfHmiLgwIh5KUV9bsE4BwFrVrVsJJSY5uQ9rWfR62SDzykDzymPCAUA/rsraYNMuUmMRG9BttVsEh+qwTgHALNByaLCyqai0DACshpZDSzFOAKAqhEOL0W0EYK0IhxajZQFgrQgHAEAB4VBz/PUPIAXCoeaqvkQGAAxDOAAACgiHGuL6RwBSYxFczbGQDUBVWAQHAJgI4ZBYWVcRC9kApEC3UmJ0GwFIhW4lAMBECIcEmI0EoO7oVkqMbiUAqdCtBACYCOGQGLORANQR4ZAY4ww
A6ohwAAAUEA4Vo2UAoIkIhxLTnty55DaAJiIcSpSd3GkZAGgjwmFKw8KDRW4Amo5wGGLak3uvly1sW1nctvKYcADQFISDiiftspM7LQMAbcflM7T6JSzKLm9Rtr/XIzQA1BOXz5jCtCuYCQYATdTZcBi3a4ib8QDoIrqVxJVRAXQT3UoAgIkQDqJrCAAGEQ5i0BgABhEOAIACwgEAUEA4AAAKCAcAQAHhAAAoaMUiONvHJR2e4kusl/T1GZUzS9Q1GeqaDHVNpo11/VhELAzb0YpwmJbtpVGrBFOirslQ12SoazJdq4tuJQBAAeEAACggHDK7UhcwAnVNhromQ12T6VRdjDkAAApoOQAACggHAEBBJ8LB9rW299l+zfbiwL6bbR+yfcD2u0Z8/lm2H7F9MP94ZkV1/pXtp/J/L9p+asTrXrT9bP66td/laPy6erb/o6+2q0e87sr8OB6yvXMOdf2+7edtP2P7PttnjHhd5cer7Gd35g/z/c/YvriKOoZ83/Ntf8H2/vx34MNDXvMO29/qe39vnVNtq74vKY6Z7Qv7jsNTtl+xfdPAa+ZyvGzfafuY7ef6to11LprJ72JEtP6fpIskXSjpHyQt9m3fKulpSadK2izpBUknDfn82yXtzB/vlPSJOdT8B5JuHbHvRUnr53j8epI+UvKak/Lj9yZJp+THdWvFdf2CpHX540+Mel+qPl7j/OySrpb0kCRLukzS43N67zZIujh//AZJ/zaktndIemBe/5/GfV9SHbOB9/VryhaKzf14Sfo5SRdLeq5vW+m5aFa/i51oOUTE/og4MGTXNkn3RMSrEfFVSYckXTridbvzx7slXVNJoTnblvRLku6u8vvM2KWSDkXEVyLiO5LuUXbcKhMRn4+IE/nTL0naWOX3W8U4P/s2SZ+JzJcknWF7Q9WFRcSRiHgyf/xtSfslnVf1952RJMesz+WSXoiIaa6+sGYR8UVJ3xzYPM65aCa/i50Ih1WcJ+mlvufLGv6Lc05EHJGyXzZJZ1dc189KOhoRB0fsD0mft73X9o6Ka1lxY960v3NEU3bcY1mVDyj7K3OYqo/XOD976uMj25skvVXS40N2v83207Yfsv0Tcyqp7H1Jfcyu0+g/0FIcL2m8c9FMjtu6NZVXQ7YflXTukF23RMT9oz5tyLZK5/aOWef1Wr3V8PaIeNn22ZIesf18/ldGJXVJ+mNJH1N2bD6mrMvrA4NfYsjnTn0sxzletm+RdELSXSO+zMyP12CZQ7YN/uxz/7/2um9uny7ps5JuiohXBnY/qazr5L/z8aS/kbRlDmWVvS/JjpntUyS9R9LNQ3anOl7jmslxa004RMQVa/i0ZUnn9z3fKOnlIa87antDRBzJm7XH1lKjVF6n7XWS3ifpklW+xsv5x2O271PWjJzqZDfu8bP9Z5IeGLJr3GM507psb5f0bkmXR97hOuRrzPx4DRjnZ6/k+IzD9snKguGuiPjc4P7+sIiIB23/ke31EVHpRebGeF+SHTNJV0l6MiKODu5Idbxy45yLZnLcut6ttEfSdbZPtb1ZWfp/ecTrtuePt0sa1RKZhSskPR8Ry8N22v5h229YeaxsUPa5Ya+dlYF+3veO+H5PSNpie3P+V9d1yo5blXVdKek3JL0nIv5nxGvmcbzG+dn3SHp/PgPnMknfWukeqFI+fnWHpP0R8ckRrzk3f51sX6rsvPCNiusa531JcsxyI1vvKY5Xn3HORbP5Xax6xL0O/5Sd0JYlvSrpqKSH+/bdomxk/4Ckq/q2/7nymU2SflTSY5IO5h/PqrDWT0v64MC2N0p6MH/8JmWzD56WtE9Z90rVx+8vJD0r6Zn8P9mGwbry51crmw3zwpzqOqSsb/Wp/N+fpDpew352SR9ceS+VNfU/le9/Vn2z5io+Rj+jrEvhmb7jdPVAbTfmx+ZpZQP7Pz2Huoa+LzU5Zj+k7GT/I33b5n68lIXTEUn/l5+/bhh1Lqrid5HLZwAACrr
erQQAGIJwAAAUEA4AgALCAQBQQDgAAAoIBwBAAeEAACggHIAK2P6p/EKFp+WrgffZfkvquoBxsQgOqIjt35Z0mqQflLQcEb+XuCRgbIQDUJH8ujZPSPpfZZdY+G7ikoCx0a0EVOcsSacruwPbaYlrASZCywGoiO09yu7CtVnZxQpvTFwSMLbW3M8BqBPb75d0IiL+0vZJkv7Z9s9HxN+nrg0YBy0HAEABYw4AgALCAQBQQDgAAAoIBwBAAeEAACggHAAABYQDAKDg/wGJ+vlSdx5lkAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(x_train_data, y_train_data, '+b');\n", "plt.xlabel('x');\n", "plt.ylabel('y');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `Trainer` expects the training data as a dictionary with variables of the graph as keys and the data for these variables as values (here, arrays with one value per data point)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "train_data = {graph.x: x_train_data, graph.y: y_train_data}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training with the `Trainer`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `Trainer` is instantiated with the untrained graph and the training data." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(graph=graph,\n", "                  data=train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The instance can then be called to obtain the trained graph." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "trained_graph = trainer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a id=\"predictions\"></a>\n", "## Outlook - Computing predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predictions can be computed with a `Predictor` (see [Predictor](./../02_objectives/03_predictor.ipynb)), which is initialized with the trained graph and the input data for the predictions.\n", "\n", "Let's create some input data for the predictions." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# 21 evenly spaced inputs covering the range of the training data\n", "x_prediction_data = np.linspace(-10, 10, 21)\n", "prediction_input_data = {graph.x: x_prediction_data}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we create a `Predictor` instance." 
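, " The `method=\"MAP\"` argument presumably requests maximum a posteriori (MAP) estimates, i.e. the most probable values given the data.\n", "\n", "To get a feeling for what these predictions should look like, assume that training amounts to an exact Bayesian update of the static variables $\\theta = (a, b)^\\top$. This is an assumption made for illustration, not a description of Halerium's internals. Take the prior mean $\\mu_0 = (1, 0)^\\top$, the prior covariance $\\Sigma_0 = 5^2 I$, the noise variance $\\sigma^2 = 1^2$, the training outputs collected in the vector $\\mathbf{y}$, and the matrix $\\Phi$ whose $i$-th row is $(x_i, 1)$. The posterior over $\\theta$ is then Gaussian with covariance $\\Sigma_N$ and mean $\\hat\\theta$:\n", "\n", "$$\n", "\\Sigma_N = \\left(\\Sigma_0^{-1} + \\sigma^{-2}\\,\\Phi^\\top \\Phi\\right)^{-1}, \\qquad \\hat\\theta = (\\hat a, \\hat b)^\\top = \\Sigma_N \\left(\\Sigma_0^{-1}\\mu_0 + \\sigma^{-2}\\,\\Phi^\\top \\mathbf{y}\\right).\n", "$$\n", "\n", "Because this posterior is Gaussian, its mode equals its mean, so $\\hat\\theta$ is the MAP estimate. The MAP prediction at a new input $x^*$ is $\\hat y^* = \\hat a\\, x^* + \\hat b$. With 40 training points spread over $[-10, 10]$, the data term $\\sigma^{-2}\\Phi^\\top\\Phi$ dwarfs the prior precision $\\Sigma_0^{-1} = 5^{-2} I$. The estimates $\\hat a$ and $\\hat b$ should therefore come out close to the least-squares fit, and hence close to `true_slope` and `true_intercept`."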
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "predictor = Predictor(graph=trained_graph,\n", "                      data=prediction_input_data,\n", "                      method=\"MAP\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We call the predictor with the variable we want predictions for, here `graph.y`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "y_prediction_data = predictor(graph.y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot these predictions on top of the training data." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "plt.plot(x_train_data, y_train_data, '+b');\n", "plt.plot(x_prediction_data, y_prediction_data, '.r');\n", "plt.xlabel('x');\n", "plt.ylabel('y');\n", "plt.legend(['training data', 'predictions']);" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }