{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Probability Estimation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "# execute the creation & training notebook first\n", "%run \"02-01-creation_and_training.ipynb\"\n", "# execute the outlier detection notebook\n", "%run \"02-05-outlier_detection.ipynb\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the [outlier detection section](./02-05-outlier_detection.ipynb) we saw how to detect outliers in a test data set and how the outlier threshold influences the detection results.\n", "The [rank estimation section](./02-07-rank_estimation.ipynb) explained the method underlying rank estimation.\n", "\n", "In this section we look at the ``.estimate_probabilities`` method, which calls the ``ProbabilityEstimator``, the basis for both outlier detection and rank estimation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Under the hood, the ``RankEstimator`` uses the ``ProbabilityEstimator`` to estimate the probability density of individual data points. Let us apply the probability estimator to the modified test data set from the outlier detection example." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n",
"    <tr><th></th><th>(a)</th><th>(b|a)</th><th>(c|a,b)</th><th>graph</th></tr>\n",
"  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>1.222395</td><td>-1.602953</td><td>1.237068</td><td>0.855358</td></tr>\n",
"    <tr><th>1</th><td>1.413743</td><td>-1.012356</td><td>1.323633</td><td>1.725051</td></tr>\n",
"    <tr><th>2</th><td>1.113812</td><td>-0.994486</td><td>1.244645</td><td>1.363961</td></tr>\n",
"    <tr><th>3</th><td>0.006604</td><td>-1.219309</td><td>0.964815</td><td>-0.252899</td></tr>\n",
"    <tr><th>4</th><td>1.405377</td><td>-0.984049</td><td>1.322710</td><td>1.744055</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>95</th><td>0.431912</td><td>-1.019023</td><td>0.986077</td><td>0.398876</td></tr>\n",
"    <tr><th>96</th><td>1.329171</td><td>-1.179034</td><td>1.279977</td><td>1.429986</td></tr>\n",
"    <tr><th>97</th><td>1.343442</td><td>-1.046909</td><td>1.310268</td><td>1.606940</td></tr>\n",
"    <tr><th>98</th><td>1.408012</td><td>-1.011526</td><td>1.315708</td><td>1.712243</td></tr>\n",
"    <tr><th>99</th><td>1.405317</td><td>-1.422481</td><td>1.324429</td><td>1.307159</td></tr>\n",
"  </tbody>\n", "</table>\n", "<p>100 rows × 4 columns</p>\n", "</div>
" ], "text/plain": [ " (a) (b|a) (c|a,b) graph\n", "0 1.222395 -1.602953 1.237068 0.855358\n", "1 1.413743 -1.012356 1.323633 1.725051\n", "2 1.113812 -0.994486 1.244645 1.363961\n", "3 0.006604 -1.219309 0.964815 -0.252899\n", "4 1.405377 -0.984049 1.322710 1.744055\n", ".. ... ... ... ...\n", "95 0.431912 -1.019023 0.986077 0.398876\n", "96 1.329171 -1.179034 1.279977 1.429986\n", "97 1.343442 -1.046909 1.310268 1.606940\n", "98 1.408012 -1.011526 1.315708 1.712243\n", "99 1.405317 -1.422481 1.324429 1.307159\n", "\n", "[100 rows x 4 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_probabilities = causal_structure.estimate_probabilities(data=mod_test_data)\n", "log_probabilities" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The returned numbers are logarithmic probability densities. Unlike probabilities, they are not confined to a fixed range: they can take any floating-point value, positive or negative, as the table above shows.\n", "\n", "Because they are densities, their individual values say little about the likelihood of a given event. Only a comparison with other probability density values in the same column provides some insight.\n", "\n", "For example, let us extract the 5 data points with the lowest probability density for the whole graph." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n",
"    <tr><th></th><th>(a)</th><th>(b|a)</th><th>(c|a,b)</th><th>graph</th></tr>\n",
"  </thead>\n", "  <tbody>\n",
"    <tr><th>50</th><td>0.987680</td><td>-13.097823</td><td>-245.485737</td><td>-262.044845</td></tr>\n",
"    <tr><th>60</th><td>0.766149</td><td>-1.219600</td><td>-24.933954</td><td>-25.538929</td></tr>\n",
"    <tr><th>74</th><td>-1.950167</td><td>-1.359456</td><td>0.267683</td><td>-3.049903</td></tr>\n",
"    <tr><th>79</th><td>-0.471609</td><td>-2.757013</td><td>0.663875</td><td>-2.556750</td></tr>\n",
"    <tr><th>13</th><td>-0.325763</td><td>-2.460974</td><td>0.718113</td><td>-2.064678</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " (a) (b|a) (c|a,b) graph\n", "50 0.987680 -13.097823 -245.485737 -262.044845\n", "60 0.766149 -1.219600 -24.933954 -25.538929\n", "74 -1.950167 -1.359456 0.267683 -3.049903\n", "79 -0.471609 -2.757013 0.663875 -2.556750\n", "13 -0.325763 -2.460974 0.718113 -2.064678" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_probabilities.sort_values(\"graph\").iloc[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that the modified data points 50 and 60 indeed have very low logarithmic probability densities: their 'graph' values of -262.04 and -25.54 lie far below -3.05, the lowest value among the unmodified data points (point 74). This is why the ``OutlierDetector`` picked them up so easily." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One more thing stands out in the logarithmic probability densities.\n", "The sum of the '(a)', '(b|a)' and '(c|a,b)' columns is close to the 'graph' column." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n",
"    <tr><th></th><th>(a)</th><th>(b|a)</th><th>(c|a,b)</th><th>graph</th><th>sum of '(a)', '(b|a)', '(c|a,b)'</th></tr>\n",
"  </thead>\n", "  <tbody>\n",
"    <tr><th>50</th><td>0.987680</td><td>-13.097823</td><td>-245.485737</td><td>-262.044845</td><td>-257.595880</td></tr>\n",
"    <tr><th>60</th><td>0.766149</td><td>-1.219600</td><td>-24.933954</td><td>-25.538929</td><td>-25.387404</td></tr>\n",
"    <tr><th>74</th><td>-1.950167</td><td>-1.359456</td><td>0.267683</td><td>-3.049903</td><td>-3.041940</td></tr>\n",
"    <tr><th>79</th><td>-0.471609</td><td>-2.757013</td><td>0.663875</td><td>-2.556750</td><td>-2.564747</td></tr>\n",
"    <tr><th>13</th><td>-0.325763</td><td>-2.460974</td><td>0.718113</td><td>-2.064678</td><td>-2.068623</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " (a) (b|a) (c|a,b) graph \\\n", "50 0.987680 -13.097823 -245.485737 -262.044845 \n", "60 0.766149 -1.219600 -24.933954 -25.538929 \n", "74 -1.950167 -1.359456 0.267683 -3.049903 \n", "79 -0.471609 -2.757013 0.663875 -2.556750 \n", "13 -0.325763 -2.460974 0.718113 -2.064678 \n", "\n", " sum of '(a)', '(b|a)', '(c|a,b)' \n", "50 -257.595880 \n", "60 -25.387404 \n", "74 -3.041940 \n", "79 -2.564747 \n", "13 -2.068623 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_probability_sums = log_probabilities[['(a)', '(b|a)', '(c|a,b)']].sum(axis=1)\n", "log_probabilities_with_sum = log_probabilities.copy()\n", "log_probabilities_with_sum[\"sum of '(a)', '(b|a)', '(c|a,b)'\"] = log_probability_sums\n", "log_probabilities_with_sum.sort_values(\"graph\").iloc[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is no accident. The [chain rule](https://en.wikipedia.org/wiki/Chain_rule_%28probability%29) (or product rule) of conditional probabilities states that the product of the (correctly conditioned) probability densities equals the total probability density. The 'graph' column holds the total logarithmic probability density, and taking the logarithm turns this product into a sum.\n", "\n", "The sum and the 'graph' column do not match exactly because the graph additionally contains the probabilities of the trained regression parameters, which also enter the total." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For further details about the ``ProbabilityEstimator``, see the [corresponding section](../02_objectives/01_probability_estimator.ipynb) in the core documentation." 
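] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a worked check of the chain-rule argument, write $a$, $b$ and $c$ for the three variables of the graph and read the column names '(a)', '(b|a)' and '(c|a,b)' as $\\log p(a)$, $\\log p(b \\mid a)$ and $\\log p(c \\mid a, b)$. Taking the logarithm of the chain rule then gives\n", "\n", "$$\n", "\\log p(a, b, c) = \\log p(a) + \\log p(b \\mid a) + \\log p(c \\mid a, b).\n", "$$\n", "\n", "For data point 74 in the table above, the right-hand side is $-1.950167 + (-1.359456) + 0.267683 = -3.041940$, which is the value in the sum column. The 'graph' column reports $-3.049903$ for the same point; the remaining gap of about $0.008$ is the part the text above attributes to the probabilities of the trained regression parameters." 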
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 }