{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Causal Structures - applied to the California School data set"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pydataset\n",
"import pylab as pl\n",
"\n",
"from halerium import CausalStructure\n",
"from halerium import Evaluator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The California Test Score Data Set"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we apply the `CausalStructure` class to the California Test Score Data Set. The preferred data format to use with `CausalStructure` is the pandas `DataFrame`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" distcod county district grspan enrltot \\\n",
"1 75119 Alameda Sunol Glen Unified KK-08 195 \n",
"2 61499 Butte Manzanita Elementary KK-08 240 \n",
"3 61549 Butte Thermalito Union Elementary KK-08 1550 \n",
"4 61457 Butte Golden Feather Union Elementary KK-08 243 \n",
"5 61523 Butte Palermo Union Elementary KK-08 1335 \n",
".. ... ... ... ... ... \n",
"416 68957 San Mateo Las Lomitas Elementary KK-08 984 \n",
"417 69518 Santa Clara Los Altos Elementary KK-08 3724 \n",
"418 72611 Ventura Somis Union Elementary KK-08 441 \n",
"419 72744 Yuba Plumas Elementary KK-08 101 \n",
"420 72751 Yuba Wheatland Elementary KK-08 1778 \n",
"\n",
" teachers calwpct mealpct computer testscr compstu \\\n",
"1 10.900000 0.510200 2.040800 67 690.799988 0.343590 \n",
"2 11.150000 15.416700 47.916698 101 661.200012 0.420833 \n",
"3 82.900002 55.032299 76.322601 169 643.599976 0.109032 \n",
"4 14.000000 36.475399 77.049202 85 647.700012 0.349794 \n",
"5 71.500000 33.108601 78.427002 171 640.849976 0.128090 \n",
".. ... ... ... ... ... ... \n",
"416 59.730000 0.101600 3.556900 195 704.300049 0.198171 \n",
"417 208.479996 1.074100 1.503800 721 706.750000 0.193609 \n",
"418 20.150000 3.563500 37.193802 45 645.000000 0.102041 \n",
"419 5.000000 11.881200 59.405899 14 672.200012 0.138614 \n",
"420 93.400002 6.923500 47.571201 313 655.750000 0.176041 \n",
"\n",
" expnstu str avginc elpct readscr mathscr \n",
"1 6384.911133 17.889910 22.690001 0.000000 691.599976 690.000000 \n",
"2 5099.380859 21.524664 9.824000 4.583333 660.500000 661.900024 \n",
"3 5501.954590 18.697226 8.978000 30.000002 636.299988 650.900024 \n",
"4 7101.831055 17.357143 8.978000 0.000000 651.900024 643.500000 \n",
"5 5235.987793 18.671329 9.080333 13.857677 641.799988 639.900024 \n",
".. ... ... ... ... ... ... \n",
"416 7290.338867 16.474134 28.716999 5.995935 700.900024 707.700012 \n",
"417 5741.462891 17.862625 41.734108 4.726101 704.000000 709.500000 \n",
"418 4402.831543 21.885857 23.733000 24.263039 648.299988 641.700012 \n",
"419 4776.336426 20.200001 9.952000 2.970297 667.900024 676.500000 \n",
"420 5993.392578 19.036402 12.502000 5.005624 660.500000 651.000000 \n",
"\n",
"[420 rows x 17 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pydataset.data(\"Caschool\")\n",
"data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data set relates average test scores in California schools to various characteristics of the schools, such as the number of students and the number of teachers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we start, we split the data into a training and a test set."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(123)\n",
"random_indices = np.random.choice([True, False], size=len(data), p=[0.75, 0.25])\n",
"data_train = data.iloc[random_indices]\n",
"data_test = data.iloc[~random_indices]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Simple structure - only inputs and outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The columns all represent different data about schools in California. We want to predict the average reading score 'readscr', the average math score 'mathscr' and the average test score 'testscr'.\n",
"For now we do not care about the details of the other columns and simply treat all other numerical columns as inputs."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"outputs = {'readscr', 'mathscr', 'testscr'}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'avginc',\n",
" 'calwpct',\n",
" 'compstu',\n",
" 'computer',\n",
" 'elpct',\n",
" 'enrltot',\n",
" 'expnstu',\n",
" 'mealpct',\n",
" 'str',\n",
" 'teachers'}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = set(data.columns[4:]) - outputs\n",
"inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define a causal structure."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"causal_structure_1 = CausalStructure([[inputs, outputs]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If data such as the training data are provided as scaling data, the CausalStructure uses them to apply the correct locations and scales when building the Graph. Alternatively, the data can be standardized beforehand."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we simply train the causal_structure with the training data and test the predictive power on the test data."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"causal_structure_1.train(data_train, method=\"MGVI\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
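{
"data": {
"text/plain": [
"[Figure: predicted vs. real test-set values, one scatter panel each for mathscr, readscr and testscr]"
]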
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Predict the outputs of the test set from its input columns.\n",
"test_data_inputs = data_test[list(inputs)]\n",
"test_data_predictions_1 = causal_structure_1.predict(test_data_inputs)\n",
"\n",
"pl.figure(figsize=(10,10))\n",
"for i, col in [(1, \"mathscr\"), (2, \"readscr\"), (3, \"testscr\")]:\n",
"    ax = pl.subplot(2,2,i)\n",
"    ax.scatter(data_test[col], test_data_predictions_1[col])\n",
"    minval, maxval = data_test[col].min(), data_test[col].max()\n",
"    ax.plot([minval, maxval], [minval, maxval], ls=\"--\", c=\"k\")\n",
"    ax.set_xlabel(\"real value\")\n",
"    ax.set_ylabel(\"prediction\")\n",
"    ax.set_title(col)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the predictions follow the real test values reasonably well."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Apart from making predictions we can also look at objectives. For example we can evaluate the performance using the `Evaluator`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'teachers': None,\n",
" 'compstu': None,\n",
" 'mealpct': None,\n",
" 'calwpct': None,\n",
" 'str': None,\n",
" 'avginc': None,\n",
" 'readscr': 0.8020755684153762,\n",
" 'enrltot': None,\n",
" 'mathscr': 0.6574450121854509,\n",
" 'testscr': 0.7577271442716229,\n",
" 'computer': None,\n",
" 'expnstu': None,\n",
" 'elpct': None}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"performance_1 = causal_structure_1.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=inputs, outputs=outputs,\n",
" metric=\"r2\")\n",
"performance_1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that the R2 score for the input columns is returned as `None`. The values for the outputs match the visual impression of a decent fit."
]
},
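{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder of what these numbers mean, the R2 score can be sketched by hand. The snippet below assumes the usual definition of the coefficient of determination, `R2 = 1 - SS_res / SS_tot`. Whether the `Evaluator` computes its r2 metric in exactly this way is an assumption, and the helper name `r2_score_sketch` is hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def r2_score_sketch(y_true, y_pred):\n",
"    # Hypothetical helper: coefficient of determination, 1 - SS_res / SS_tot.\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    ss_res = np.sum((y_true - y_pred) ** 2)         # squared prediction errors\n",
"    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # spread around the mean\n",
"    return 1.0 - ss_res / ss_tot\n",
"\n",
"y_true = [650.0, 660.0, 670.0, 680.0]  # made-up scores for illustration\n",
"print(r2_score_sketch(y_true, y_true))                        # perfect prediction\n",
"print(r2_score_sketch(y_true, [665.0] * 4))                   # always predicting the mean\n",
"print(r2_score_sketch(y_true, [680.0, 670.0, 660.0, 650.0]))  # worse than the mean\n",
"```\n",
"\n",
"A score of 1 corresponds to a perfect prediction, a score of 0 to always predicting the mean of the real values, and a negative score to a prediction that is worse than that mean."
]
},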
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To judge whether this makes sense we have to know what these columns mean."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caschool\n",
"\n",
"PyDataset Documentation (adopted from R Documentation. The displayed examples are in R)\n",
"\n",
"## The California Test Score Data Set\n",
"\n",
"### Description\n",
"\n",
"a cross-section from 1998-1999\n",
"\n",
"_number of observations_ : 420\n",
"\n",
"_observation_ : schools\n",
"\n",
"_country_ : United States\n",
"\n",
"### Usage\n",
"\n",
" data(Caschool)\n",
"\n",
"### Format\n",
"\n",
"A dataframe containing :\n",
"\n",
"distcod\n",
"\n",
"disctric code\n",
"\n",
"county\n",
"\n",
"county\n",
"\n",
"district\n",
"\n",
"district\n",
"\n",
"grspan\n",
"\n",
"grade span of district\n",
"\n",
"enrltot\n",
"\n",
"total enrollment\n",
"\n",
"teachers\n",
"\n",
"number of teachers\n",
"\n",
"calwpct\n",
"\n",
"percent qualifying for CalWorks\n",
"\n",
"mealpct\n",
"\n",
"percent qualifying for reduced-price lunch\n",
"\n",
"computer\n",
"\n",
"number of computers\n",
"\n",
"testscr\n",
"\n",
"average test score (read.scr+math.scr)/2\n",
"\n",
"compstu\n",
"\n",
"computer per student\n",
"\n",
"expnstu\n",
"\n",
"expenditure per student\n",
"\n",
"str\n",
"\n",
"student teacher ratio\n",
"\n",
"avginc\n",
"\n",
"district average income\n",
"\n",
"elpct\n",
"\n",
"percent of English learners\n",
"\n",
"readscr\n",
"\n",
"average reading score\n",
"\n",
"mathscr\n",
"\n",
"average math score\n",
"\n",
"### Source\n",
"\n",
"California Department of Education http://www.cde.ca.gov.\n",
"\n",
"### References\n",
"\n",
"Stock, James H. and Mark W. Watson (2003) _Introduction to Econometrics_,\n",
"Addison-Wesley Educational Publishers,\n",
"http://wps.aw.com/aw_stockwatsn_economtrcs_1, chapter 4–7.\n",
"\n",
"### See Also\n",
"\n",
"`Index.Source`, `Index.Economics`, `Index.Econometrics`, `Index.Observations`\n",
"\n",
"\n"
]
}
],
"source": [
"pydataset.data(\"Caschool\", show_doc=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that in the fully connected model (all inputs influence all outputs) the biggest influences seem to be the total enrollment 'enrltot', the number of teachers 'teachers' and the percentage of students that get subsidized meals 'mealpct'. The influence of the other tests on 'testscr' is zero.\n",
"\n",
"We might now ask whether it makes sense that these quantities are the direct main influences or whether they are only confounders. If we take a look at the correlations between the inputs, we see that they are significantly correlated, which can lead to confounding and the explaining-away effect."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"          teachers   compstu   calwpct   mealpct       str    avginc   enrltot  computer   expnstu     elpct\n",
"teachers  1.000000 -0.205573  0.105457  0.142855  0.272616  0.023591  0.997424  0.936511 -0.082112  0.357818\n",
"compstu  -0.205573  1.000000 -0.171575 -0.217232 -0.301937  0.183565 -0.213354 -0.041737  0.274414 -0.272306\n",
"calwpct   0.105457 -0.171575  1.000000  0.735591  0.017576 -0.511519  0.099840  0.072180  0.081346  0.350795\n",
"mealpct   0.142855 -0.217232  0.735591  1.000000  0.155836 -0.708045  0.147787  0.076777 -0.031435  0.679559\n",
"str       0.272616 -0.301937  0.017576  0.155836  1.000000 -0.222932  0.305878  0.243671 -0.590435  0.226483\n",
"avginc    0.023591  0.183565 -0.511519 -0.708045 -0.222932  1.000000  0.010920  0.076274  0.298755 -0.362238\n",
"enrltot   0.997424 -0.213354  0.099840  0.147787  0.305878  0.010920  1.000000  0.929746 -0.100289  0.365512\n",
"computer  0.936511 -0.041737  0.072180  0.076777  0.243671  0.076274  0.929746  1.000000 -0.061414  0.292201\n",
"expnstu  -0.082112  0.274414  0.081346 -0.031435 -0.590435  0.298755 -0.100289 -0.061414  1.000000 -0.049224\n",
"elpct     0.357818 -0.272306  0.350795  0.679559  0.226483 -0.362238  0.365512  0.292201 -0.049224  1.000000"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_train[list(inputs)].corr().style.background_gradient(cmap='coolwarm', vmin=-1, vmax=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see for example that the total enrollment 'enrltot' and the number of teachers 'teachers' are very strongly correlated with each other (and also with the number of computers 'computer')."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"          teachers   compstu   calwpct   mealpct       str    avginc   enrltot  computer   expnstu     elpct\n",
"testscr  -0.157794  0.278550 -0.628098 -0.877528 -0.227190  0.734062 -0.167217 -0.084573  0.150482 -0.662471\n",
"readscr  -0.189992  0.285396 -0.615917 -0.887392 -0.252007  0.720649 -0.199871 -0.117787  0.178198 -0.704534\n",
"mathscr  -0.116787  0.260359 -0.617021 -0.832906 -0.191503  0.720223 -0.125351 -0.045295  0.114616 -0.591257"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_train[list(inputs | outputs)].corr().loc[list(inputs), list(outputs)].T.style.background_gradient(cmap='coolwarm', vmin=-1, vmax=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Advanced structure"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the information we have about the data set we can now try to make a more realistic model of the involved columns.\n",
"\n",
"We start with a number of assumptions which result in dependencies."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"dependencies = []"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First of all we see in the data documentation that 'testscr' is just the average of the reading score 'readscr' and the math score 'mathscr'."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"dependencies += [\n",
" [[\"mathscr\", \"readscr\"], \"testscr\"]\n",
"]"
]
},
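{
"cell_type": "markdown",
"metadata": {},
"source": [
"This relation can be checked directly on a few rows. The sketch below uses the first three rows of the table shown earlier, typed in by hand; the variable names `sample` and `average` are only for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# First three rows of the data set, copied from the table shown earlier.\n",
"sample = pd.DataFrame({\n",
"    'readscr': [691.599976, 660.500000, 636.299988],\n",
"    'mathscr': [690.000000, 661.900024, 650.900024],\n",
"    'testscr': [690.799988, 661.200012, 643.599976],\n",
"})\n",
"\n",
"average = (sample['readscr'] + sample['mathscr']) / 2\n",
"# The tolerance absorbs rounding in the stored values.\n",
"print(np.allclose(average, sample['testscr'], atol=1e-3))\n",
"```\n",
"\n",
"For these rows the recomputed average agrees with 'testscr' up to rounding, so the check prints `True`."
]
},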
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assumption 1: We believe the student-teacher ratio is what actually matters for the test performance. The student-teacher ratio of course depends on the number of students and the number of teachers."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"dependencies += [\n",
" [\"str\", [\"mathscr\", \"readscr\"]],\n",
" [[\"teachers\", \"enrltot\"], \"str\"]\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assumption 2: We believe the number of teachers depends on the number of students and the expenditure per student."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"dependencies += [\n",
" [[\"expnstu\", \"enrltot\"], \"teachers\"]\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assumption 3: The number of computers per student might indicate the learning possibilities at the school and therefore influence the test results. The number of computers per student is of course determined by the number of students and the number of computers."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"dependencies += [\n",
" [\"compstu\", [\"mathscr\", \"readscr\"]],\n",
" [[\"computer\", \"enrltot\"], \"compstu\"]\n",
"]"
]
},
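{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both ratios can be recomputed from the raw counts. The following is a minimal sketch using the first three rows of the table shown earlier, typed in by hand; the frame names `sample` and `check` are only for illustration.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# First three rows of the data set, copied from the table shown earlier.\n",
"sample = pd.DataFrame({\n",
"    'enrltot': [195, 240, 1550],\n",
"    'teachers': [10.900000, 11.150000, 82.900002],\n",
"    'computer': [67, 101, 169],\n",
"    'str': [17.889910, 21.524664, 18.697226],\n",
"    'compstu': [0.343590, 0.420833, 0.109032],\n",
"})\n",
"\n",
"# Recompute the two ratios and compare them with the stored columns.\n",
"check = pd.DataFrame({\n",
"    'str_recomputed': sample['enrltot'] / sample['teachers'],\n",
"    'str_stored': sample['str'],\n",
"    'compstu_recomputed': sample['computer'] / sample['enrltot'],\n",
"    'compstu_stored': sample['compstu'],\n",
"})\n",
"print(check.round(4))\n",
"```\n",
"\n",
"The recomputed and stored columns should agree up to rounding, which supports treating 'str' and 'compstu' as quantities derived from 'enrltot', 'teachers' and 'computer'."
]
},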
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assumption 4: The number of computers probably depends on the number of students and the funding per student."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"dependencies += [\n",
" [[\"expnstu\", \"enrltot\"], \"computer\"]\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assumption 5: We assume poorer students achieve worse test results on average. In the data we have the percent qualifying for reduced-price lunch 'mealpct', the percent qualifying for CalWorks 'calwpct' and the average income in the school district 'avginc'. Since the latter is a bit indirect we assume 'mealpct' and 'calwpct' are direct influences on the test results while both of them are influenced by 'avginc'."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"dependencies += [\n",
" [[\"mealpct\", \"calwpct\"], [\"mathscr\", \"readscr\"]],\n",
" [\"avginc\", [\"mealpct\", \"calwpct\"]]\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assumption 6: We assume that being a non-native speaker directly influences both the test results and the economic situation of the student. The percent of English learners is 'elpct'."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"dependencies += [\n",
" [\"elpct\", [\"mathscr\", \"readscr\"]],\n",
" [\"elpct\", [\"mealpct\", \"calwpct\"]]\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we put these assumptions into a causal structure."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"causal_structure_2 = CausalStructure(dependencies)"
]
},
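{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which edges such a list describes, each `[causes, effects]` entry can be expanded into individual cause-effect pairs. The sketch below assumes that every listed cause is connected to every listed effect, as in the fully connected model above; `expand_dependencies` is a hypothetical helper for illustration and not part of halerium.\n",
"\n",
"```python\n",
"def expand_dependencies(dependencies):\n",
"    # Hypothetical helper: turn [causes, effects] entries into (cause, effect) pairs.\n",
"    def as_list(item):\n",
"        return [item] if isinstance(item, str) else list(item)\n",
"\n",
"    edges = []\n",
"    for causes, effects in dependencies:\n",
"        for cause in as_list(causes):\n",
"            for effect in as_list(effects):\n",
"                edges.append((cause, effect))\n",
"    return edges\n",
"\n",
"# Two entries in the same format as the dependencies defined above.\n",
"example = [\n",
"    [['mathscr', 'readscr'], 'testscr'],\n",
"    ['str', ['mathscr', 'readscr']],\n",
"]\n",
"for cause, effect in expand_dependencies(example):\n",
"    print(cause, '->', effect)\n",
"```\n",
"\n",
"Applied to the full `dependencies` list, the same expansion gives a quick overview of every assumed edge before training."
]
},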
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and train it"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"causal_structure_2.train(data_train, method=\"MGVI\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us have a look at the performance on the test data."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"test_data_inputs = data_test[inputs]\n",
"test_data_predictions_2 = causal_structure_2.predict(test_data_inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First visually..."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
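{
"data": {
"text/plain": [
"[Figure: predicted vs. real test-set values of the structured model, one scatter panel each for mathscr, readscr and testscr]"
]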
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"pl.figure(figsize=(10,10))\n",
"for i, col in [(1, \"mathscr\"), (2, \"readscr\"), (3, \"testscr\")]:\n",
"    ax = pl.subplot(2,2,i)\n",
"    ax.scatter(data_test[col], test_data_predictions_2[col])\n",
"    minval, maxval = data_test[col].min(), data_test[col].max()\n",
"    ax.plot([minval, maxval], [minval, maxval], ls=\"--\", c=\"k\")\n",
"    ax.set_xlabel(\"real value\")\n",
"    ax.set_ylabel(\"prediction\")\n",
"    ax.set_title(col)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and by the R2 score."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"performance_2 = causal_structure_2.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=inputs, outputs=outputs,\n",
" metric=\"r2\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'performance of simple model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': 0.8020755684153762,\n",
" 'mathscr': 0.6574450121854509,\n",
" 'testscr': 0.7577271442716229}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'performance of structured model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': 0.7736833098673347,\n",
" 'testscr': 0.733377201980135,\n",
" 'mathscr': 0.6371819769563171}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(\"performance of simple model:\")\n",
"display({\n",
" key: value for key, value in performance_1.items() if key in outputs\n",
"})\n",
"display(\"performance of structured model:\")\n",
"display({\n",
" key: value for key, value in performance_2.items() if key in outputs\n",
"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the R2 scores are actually slightly worse than for the simple causal structure, which only had inputs and outputs. Maybe some of our assumptions were incomplete or even wrong.\n",
"\n",
"So what have we gained if the prediction performance is not better? We can see that once we start comparing predictions on partial data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparing the performance of the simple and the advanced structure."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we established, the performance when predicting the outputs based on all inputs is similar in both models. When all input data are available and the training and test sets are split randomly, it does not matter much whether the learned behavior is based on real causality or on confounders.\n",
"\n",
"However, if we make predictions from incomplete data, this changes."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example 1: predicting based on subsidized meals and CalWorks\n",
"Let's try a prediction based only on the percentage of students qualifying for subsidized meals and the percentage in the CalWorks program."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'prediction based on simple model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': 0.6342623417400537,\n",
" 'mathscr': 0.5509301033970446,\n",
" 'testscr': 0.590015453901932}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'prediction based on structured model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': 0.7078249997685941,\n",
" 'testscr': 0.6881614958989846,\n",
" 'mathscr': 0.6110603334844447}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(\"prediction based on simple model:\")\n",
"display({\n",
" key: value for key, value in \n",
" causal_structure_1.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=['calwpct', 'mealpct'], outputs=outputs,\n",
" metric=\"r2\").items()\n",
" if key in outputs\n",
"})\n",
"display(\"prediction based on structured model:\")\n",
"display({\n",
" key: value for key, value in \n",
" causal_structure_2.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=['calwpct', 'mealpct'], outputs=outputs,\n",
" metric=\"r2\").items()\n",
" if key in outputs\n",
"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that in this scenario the structured model outperforms the simple model, but both models still achieve decent scores."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example 2: predicting based on CalWorks only\n",
"Let's try a prediction based on only the percentage in the CalWorks program."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'prediction based on simple model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': -0.033142940725477965,\n",
" 'mathscr': 0.11518897669104289,\n",
" 'testscr': 0.029182778885538996}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'prediction based on structured model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': 0.32351923326782916,\n",
" 'testscr': 0.3688571477752969,\n",
" 'mathscr': 0.37847951309032846}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(\"prediction based on simple model:\")\n",
"display({\n",
" key: value for key, value in \n",
" causal_structure_1.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=['calwpct'], outputs=outputs,\n",
" metric=\"r2\").items()\n",
" if key in outputs\n",
"})\n",
"display(\"prediction based on structured model:\")\n",
"display({\n",
" key: value for key, value in \n",
" causal_structure_2.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=['calwpct'], outputs=outputs,\n",
" metric=\"r2\").items()\n",
" if key in outputs\n",
"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this scenario the structured model outperforms the simple model significantly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example 3: predicting based on district average income\n",
"Let's try a prediction based on only the district average income."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'prediction based on simple model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': 0.2816587802736583,\n",
" 'mathscr': 0.2002535289464773,\n",
" 'testscr': 0.2320481281685418}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'prediction based on structured model:'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'readscr': 0.34039792627046295,\n",
" 'testscr': 0.3543316153774535,\n",
" 'mathscr': 0.3358698650229551}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(\"prediction based on simple model:\")\n",
"display({\n",
" key: value for key, value in \n",
" causal_structure_1.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=['avginc'], outputs=outputs,\n",
" metric=\"r2\").items()\n",
" if key in outputs\n",
"})\n",
"display(\"prediction based on structured model:\")\n",
"display({\n",
" key: value for key, value in \n",
" causal_structure_2.evaluate_objective(Evaluator, data=data_test,\n",
" inputs=['avginc'], outputs=outputs,\n",
" metric=\"r2\").items()\n",
" if key in outputs\n",
"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this scenario the structured model again outperforms the simple model on all three scores."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To summarize, a causal structure based on realistic dependency assumptions will not necessarily outperform a black-box model with only inputs and outputs. It is, however, more robust when making predictions based on incomplete data."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}