{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Performance Evaluation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "# execute the creation & training notebook first\n", "%run \"02-01-creation_and_training.ipynb\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After [training](./02-01-creation_and_training.ipynb), we might want to know how well our causal structure can predict data.\n", "We can evaluate the predictive power of our causal structure using the ``.evaluate`` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this we will need some test data. We create this artificial data now." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# fix the seed so the test data is reproducible\n", "np.random.seed(42)\n", "n_data = 100\n", "# '(b|a)' depends on '(a)'; '(c|a,b)' depends on both, each with Gaussian noise\n", "parameter_a = 5 + np.random.randn(n_data) * 0.1\n", "parameter_b = parameter_a * (-35) + 150 + np.random.randn(n_data) * 1.\n", "parameter_c = parameter_a * 10.5 + parameter_b * (.5) + np.random.randn(n_data) * 0.01\n", "\n", "test_data = pd.DataFrame(data={\"(a)\": parameter_a,\n", " \"(b|a)\": parameter_b,\n", " \"(c|a,b)\": parameter_c})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Apart from the test data, we have to specify which parameter(s) serve as prediction inputs. Let's start with '(a)' as the only input. This means we evaluate how well the values of the other parameters can be predicted from the value of '(a)'." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(a) NaN\n", "(b|a) 0.927296\n", "(c|a,b) 0.697470\n", "Name: r2, dtype: float64" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluation = causal_structure.evaluate(data=test_data,\n", " inputs=[\"(a)\"])\n", "evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, ``.evaluate`` computes the R2-score.\n", "We see that, based on '(a)', the model achieves a prediction score of ~0.9 on '(b|a)' and ~0.7 on '(c|a,b)'. For '(a)' we get ``NaN``, since it was part of the inputs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we change the inputs to '(a)' and '(b|a)', we expect an increase in the score for '(c|a,b)'.\n", "We can also pass further arguments to the ``evaluate`` method. Say that, in addition to the R2-score, we also want to know the root mean squared error (\"rmse\")." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
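<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>(a)</th>\n", " <th>(b|a)</th>\n", " <th>(c|a,b)</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>r2</th>\n", " <td>None</td>\n", " <td>None</td>\n", " <td>0.998541</td>\n", " </tr>\n", " <tr>\n", " <th>rmse</th>\n", " <td>None</td>\n", " <td>None</td>\n", " <td>0.032054</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>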
" ], "text/plain": [ " (a) (b|a) (c|a,b)\n", "r2 None None 0.998541\n", "rmse None None 0.032054" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluation = causal_structure.evaluate(data=test_data,\n", " inputs=[\"(a)\", \"(b|a)\"],\n", " metric=(\"r2\", \"rmse\"))\n", "evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, the R2-score for '(c|a,b)' increased significantly, to ~0.999. Additionally, we get the root mean squared error, which is very low (~0.03). Both inputs, '(a)' and '(b|a)', are reported as ``None``." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For further details about the ``Evaluator``, see the [corresponding section](../02_objectives/04_evaluator.ipynb) in the core documentation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the [next section](./02-05-outlier_detection.ipynb) we will have a look at outlier detection." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 }