Causal Structures - applied to the California School data set

The imports

[1]:
import numpy as np
import pydataset
import pylab as pl

from halerium import CausalStructure
from halerium import Evaluator

The California Test Score Data Set

In this example we apply the CausalStructure class to the California Test Score Data Set. The preferred data format to use with CausalStructure is the pandas DataFrame.

[2]:
data = pydataset.data("Caschool")
data
[2]:
distcod county district grspan enrltot teachers calwpct mealpct computer testscr compstu expnstu str avginc elpct readscr mathscr
1 75119 Alameda Sunol Glen Unified KK-08 195 10.900000 0.510200 2.040800 67 690.799988 0.343590 6384.911133 17.889910 22.690001 0.000000 691.599976 690.000000
2 61499 Butte Manzanita Elementary KK-08 240 11.150000 15.416700 47.916698 101 661.200012 0.420833 5099.380859 21.524664 9.824000 4.583333 660.500000 661.900024
3 61549 Butte Thermalito Union Elementary KK-08 1550 82.900002 55.032299 76.322601 169 643.599976 0.109032 5501.954590 18.697226 8.978000 30.000002 636.299988 650.900024
4 61457 Butte Golden Feather Union Elementary KK-08 243 14.000000 36.475399 77.049202 85 647.700012 0.349794 7101.831055 17.357143 8.978000 0.000000 651.900024 643.500000
5 61523 Butte Palermo Union Elementary KK-08 1335 71.500000 33.108601 78.427002 171 640.849976 0.128090 5235.987793 18.671329 9.080333 13.857677 641.799988 639.900024
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
416 68957 San Mateo Las Lomitas Elementary KK-08 984 59.730000 0.101600 3.556900 195 704.300049 0.198171 7290.338867 16.474134 28.716999 5.995935 700.900024 707.700012
417 69518 Santa Clara Los Altos Elementary KK-08 3724 208.479996 1.074100 1.503800 721 706.750000 0.193609 5741.462891 17.862625 41.734108 4.726101 704.000000 709.500000
418 72611 Ventura Somis Union Elementary KK-08 441 20.150000 3.563500 37.193802 45 645.000000 0.102041 4402.831543 21.885857 23.733000 24.263039 648.299988 641.700012
419 72744 Yuba Plumas Elementary KK-08 101 5.000000 11.881200 59.405899 14 672.200012 0.138614 4776.336426 20.200001 9.952000 2.970297 667.900024 676.500000
420 72751 Yuba Wheatland Elementary KK-08 1778 93.400002 6.923500 47.571201 313 655.750000 0.176041 5993.392578 19.036402 12.502000 5.005624 660.500000 651.000000

420 rows × 17 columns

The data set relates the average test scores of California schools to various characteristics of the schools, such as the number of students and teachers.

Before we start, we split the data into a training and a test set.

[3]:
np.random.seed(123)
random_indices = np.random.choice([True, False], size=len(data), p=[0.75, 0.25])
data_train = data.iloc[random_indices]
data_test = data.iloc[~random_indices]
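Because the split uses a random boolean mask rather than a fixed cut, every row lands in exactly one of the two sets, but the training share only approximates 75 %. The sketch below rebuilds the mask on its own and checks both properties. The helper name `boolean_split` is ours. `np.random.RandomState(123)` should draw the same sequence as the global `np.random.seed(123)` call above, because the global functions use a `RandomState` internally.

```python
import numpy as np

def boolean_split(n_rows, train_fraction=0.75, seed=123):
    """Draw a boolean mask: True marks a training row, False a test row."""
    rng = np.random.RandomState(seed)  # same generator type as np.random.seed / np.random.choice
    return rng.choice([True, False], size=n_rows, p=[train_fraction, 1 - train_fraction])

mask = boolean_split(420)           # 420 rows, as in the Caschool data
train_rows = np.flatnonzero(mask)   # positions that end up in data_train
test_rows = np.flatnonzero(~mask)   # positions that end up in data_test

# Every row is used exactly once ...
assert len(np.intersect1d(train_rows, test_rows)) == 0
assert len(train_rows) + len(test_rows) == 420
# ... but the training share is only approximately 75 %.
print(len(train_rows), len(test_rows), len(train_rows) / 420)
```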

Simple structure - only inputs and outputs

All columns describe schools in California. We want to predict the average reading score ‘readscr’, the average math score ‘mathscr’ and the overall average score ‘testscr’. For now we do not care about the details of the other columns and simply treat all remaining numerical columns, except the district code ‘distcod’, as inputs.

[4]:
outputs = {'readscr', 'mathscr', 'testscr'}
[5]:
inputs = set(data.columns[4:]) - outputs
inputs
[5]:
{'avginc',
 'calwpct',
 'compstu',
 'computer',
 'elpct',
 'enrltot',
 'expnstu',
 'mealpct',
 'str',
 'teachers'}
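The slice `data.columns[4:]` skips the district code and the three text columns before the score columns are subtracted. The sketch below repeats this selection on the column names copied from the table above, so it runs without `pydataset`. It also converts the set to a sorted list, because recent pandas versions refuse a set as a column indexer in `df[...]` or `.loc[...]`; that conversion is our addition, not part of the original notebook.

```python
# Column order as printed for the Caschool data frame above.
columns = ["distcod", "county", "district", "grspan", "enrltot", "teachers",
           "calwpct", "mealpct", "computer", "testscr", "compstu", "expnstu",
           "str", "avginc", "elpct", "readscr", "mathscr"]
outputs = {"readscr", "mathscr", "testscr"}

# Skip distcod, county, district and grspan, then remove the three score columns.
inputs = set(columns[4:]) - outputs

# A list (sorted for a reproducible order) is safe to use for pandas column selection.
input_columns = sorted(inputs)
print(input_columns)
```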

Now we define a causal structure.

[6]:
causal_structure_1 = CausalStructure([[inputs, outputs]])

The CausalStructure uses the training data as scaling data, i.e. to apply the correct locations and scales when building the graph. Alternatively, we could have standardized the data ourselves.
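For reference, standardizing the data ourselves would amount to a column-wise z-score with the training mean as location and the training standard deviation as scale. The sketch below is generic z-scoring and makes no claim about how CausalStructure scales internally; the function names are hypothetical and the toy values are rounded ‘enrltot’ and ‘avginc’ entries from the first rows above.

```python
import numpy as np

def fit_scaler(train):
    """Column-wise location (mean) and scale (standard deviation) of the training data."""
    train = np.asarray(train, dtype=float)
    return train.mean(axis=0), train.std(axis=0)

def standardize(x, loc, scale):
    """Map data to zero mean / unit variance using the *training* location and scale."""
    return (np.asarray(x, dtype=float) - loc) / scale

def unstandardize(z, loc, scale):
    """Map standardized values back to the original units."""
    return np.asarray(z, dtype=float) * scale + loc

# Toy stand-in for two training columns (enrltot, avginc).
train = np.array([[195.0, 22.7], [240.0, 9.8], [1550.0, 9.0], [243.0, 9.0]])
loc, scale = fit_scaler(train)
z = standardize(train, loc, scale)
print(z.mean(axis=0), z.std(axis=0))  # approximately 0 and 1 per column
```

The test data would be transformed with the same `loc` and `scale`, so that no information from the test set leaks into the scaling.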

Now we simply train causal_structure_1 on the training data and test its predictive power on the test data.

[7]:
causal_structure_1.train(data_train, method="MGVI")
[8]:
test_data_inputs = data_test[list(inputs)]  # pandas needs a list, not a set, as column indexer
test_data_predictions_1 = causal_structure_1.predict(test_data_inputs)
test_data_predictions_1
[8]:
teachers compstu mealpct calwpct str avginc readscr enrltot mathscr testscr computer expnstu elpct
7 10.000000 0.143590 94.623703 12.903200 19.500000 6.577000 617.612475 195.0 625.262781 621.376790 28.0 5253.331055 68.717949
22 21.000000 0.111579 91.546402 21.649500 22.619047 9.630000 627.282873 475.0 631.942011 629.668145 53.0 4542.104980 16.210526
38 17.500000 0.063380 94.932404 14.527000 16.228571 14.558000 628.064876 284.0 631.237212 629.626784 18.0 6516.533203 32.394367
39 280.000000 0.104655 81.117302 19.571699 19.178572 22.059999 632.454008 5370.0 639.352250 635.985587 562.0 4559.176758 65.512100
45 221.000000 0.076092 92.743103 24.659500 19.266968 10.050500 622.697385 4258.0 630.101735 626.458898 324.0 5092.917480 36.801315
48 39.000000 0.104423 67.444702 19.778900 20.871796 15.177000 643.666348 814.0 643.284873 643.517021 85.0 4518.016113 13.759213
52 74.500000 0.151250 79.725502 27.261400 21.476511 10.472000 630.008711 1600.0 632.956701 631.485806 242.0 4720.086426 35.250000
59 126.120003 0.054150 52.396400 15.459500 22.256580 12.997000 642.490168 2807.0 644.200948 643.344747 152.0 4353.019531 33.487709
65 50.000000 0.155695 82.131699 27.168200 19.139999 9.082000 633.654592 957.0 634.547127 634.078879 149.0 5306.132812 17.659353
67 209.839996 0.065438 93.922699 17.702600 20.682425 14.601625 625.472844 4340.0 632.377607 629.086134 284.0 5181.579590 31.198156
72 141.210007 0.091396 68.396400 21.258801 21.152891 13.730599 641.871966 2987.0 641.980173 642.028016 273.0 4826.666016 14.429193
85 26.500000 0.075949 75.738403 21.518999 17.886793 8.258000 632.714856 474.0 633.982099 633.323229 36.0 5329.002441 28.691982
86 57.700001 0.248654 62.212799 29.446800 19.306759 8.896000 650.578864 1114.0 645.014199 647.825519 277.0 5929.809082 1.795332
92 15.000000 0.183333 72.185402 48.344398 20.000000 11.553000 647.135993 300.0 640.133688 643.631130 55.0 5869.389160 0.666667
94 8.000000 0.232877 54.109600 25.342501 18.250000 7.105000 654.489627 146.0 647.243790 650.944615 34.0 6231.601562 0.000000
107 846.380005 0.060922 51.242100 20.523701 22.923510 14.127667 647.845877 19402.0 646.546757 647.165516 1182.0 4906.229980 18.436245
109 136.830002 0.151469 64.097702 25.829800 19.155155 13.390000 650.950101 2621.0 646.368921 648.613525 397.0 5718.551270 2.022129
114 332.100006 0.076557 71.023804 12.279100 19.626617 14.242901 634.948600 6518.0 639.780808 637.422017 499.0 5172.019531 43.494938
120 26.610001 0.202532 82.067497 20.675100 17.812851 11.238000 630.479794 474.0 634.182173 632.480049 96.0 4945.031738 35.443039
121 30.000000 0.266544 74.862396 20.550501 18.133333 10.056000 641.267530 544.0 641.698759 641.549939 145.0 5223.025391 8.639706
125 114.500000 0.041667 53.210602 14.414000 19.283842 9.630000 647.832903 2208.0 647.167778 647.609700 92.0 4960.515625 9.873188
126 55.000000 0.131474 44.259701 11.322200 22.818182 15.274000 653.823708 1255.0 651.019501 652.490324 165.0 4880.040527 16.095617
129 98.000000 0.086646 67.751801 24.415100 20.020409 9.485000 639.821536 1962.0 639.065021 639.379896 170.0 5204.687500 18.195719
134 39.500000 0.000000 34.078899 11.710500 19.037975 14.076000 662.788831 752.0 656.968347 659.849665 0.0 5500.107910 0.132979
135 537.880005 0.184284 39.036800 7.487500 17.342157 25.487333 658.662077 9328.0 663.079096 661.096120 1719.0 5360.517090 50.857632
143 391.420013 0.158389 71.912102 12.399000 21.501200 12.669900 633.088379 8416.0 639.744562 636.421265 1333.0 5065.911133 43.750000
146 232.970001 0.130312 65.915001 17.476101 19.796539 12.827000 642.932220 4612.0 643.894790 643.420662 601.0 5124.836426 16.652212
150 6.000000 0.000000 35.820900 11.194000 22.166666 11.426000 659.636291 133.0 654.006918 656.901619 0.0 5213.089355 0.751880
152 11.660000 0.135135 40.625000 7.589300 19.039452 14.578000 656.280216 222.0 653.613198 655.002750 30.0 4886.983887 13.513513
158 11.500000 0.154472 43.902401 9.756100 21.391304 11.081000 652.815764 246.0 649.519740 651.190263 38.0 5161.202148 16.666668
162 39.900002 0.068776 32.187099 10.178800 18.220551 9.972000 655.948312 727.0 653.918119 654.969426 50.0 4842.607910 13.067400
166 38.000000 0.075282 35.759102 7.026300 20.973684 16.292999 661.393048 797.0 658.033962 659.672199 60.0 4674.298828 3.262233
168 13.700000 0.217021 67.234001 6.808500 17.153284 16.622999 641.824404 235.0 642.706149 642.305920 51.0 5621.687988 39.574467
169 371.100006 0.156137 51.567402 23.197500 22.349771 12.549883 649.462177 8294.0 649.298686 649.430883 1295.0 4948.868164 10.127803
171 8.250000 0.106667 23.225800 3.225800 18.181818 18.326000 665.447342 150.0 661.930925 663.703699 16.0 5132.788574 14.000000
173 117.800003 0.148323 52.252300 17.975100 19.745331 11.426000 652.022155 2326.0 649.161762 650.526578 345.0 5149.186523 6.405847
174 30.500000 0.325349 54.761902 20.634899 16.426229 8.830000 652.584348 501.0 650.315801 651.378747 163.0 5373.206543 2.395210
175 28.270000 0.102128 89.699600 31.330500 16.625397 11.176000 636.501320 470.0 634.375519 635.416954 48.0 6485.166016 5.957447
184 41.150002 0.096154 50.763401 16.539400 18.955042 14.197000 655.280541 780.0 650.376220 652.903479 75.0 5261.370605 3.846154
186 6.750000 0.178571 45.714298 7.857100 20.740740 10.639000 652.446694 140.0 651.186929 651.756556 25.0 4566.270020 10.714286
189 5.000000 0.231481 32.110100 3.669700 21.600000 9.665000 659.411046 108.0 658.471727 658.905142 25.0 4432.478027 4.629630
204 15.850000 0.117264 28.664499 5.537500 19.369085 14.578000 662.314650 307.0 659.310833 660.853105 36.0 4718.163086 7.491857
205 17.500000 0.161383 53.602299 18.443800 19.828571 10.202000 649.437663 347.0 646.738300 648.038495 56.0 4751.299805 9.798271
211 61.990002 0.145261 62.267502 34.189499 18.212614 11.291000 651.567296 1129.0 644.711060 648.104441 164.0 5805.636719 0.000000
214 229.199997 0.125406 25.888500 8.542000 21.500874 14.623000 663.465537 4928.0 660.573657 661.936855 618.0 5139.937500 5.925324
219 43.500000 0.110983 12.574100 3.084200 19.885057 14.578000 669.613275 865.0 666.541670 668.121803 96.0 4797.503418 3.352601
222 329.119995 0.097913 36.889999 12.678500 19.363758 14.597667 658.688447 6373.0 658.211347 658.471972 624.0 4733.746582 6.260787
224 138.119995 0.066138 34.447102 11.126400 21.017956 16.271999 663.085644 2903.0 657.967585 660.539652 192.0 5431.177246 3.169135
225 29.650000 0.106195 36.637199 6.725700 19.055649 13.630000 657.091211 565.0 653.529834 655.365003 60.0 5382.088867 16.991150
234 188.199997 0.119102 49.426899 9.770800 18.692881 22.473000 653.031523 3518.0 652.258787 652.659949 419.0 5642.832031 40.108017
235 924.570007 0.124443 50.970798 18.931900 20.868078 15.684654 652.773767 19294.0 653.725935 653.493526 2401.0 5280.023438 16.663212
238 6.000000 0.358974 36.134499 8.403400 19.500000 16.356001 668.170509 117.0 662.806404 665.457846 42.0 6039.985840 0.854701
243 242.500000 0.091066 37.291100 8.318900 21.463917 15.749917 653.984485 5205.0 653.846936 653.818800 474.0 4954.466309 22.939482
247 153.199997 0.103080 26.943300 3.530500 20.770235 12.900000 660.471034 3182.0 659.690492 659.956172 328.0 4777.512695 8.956632
249 588.840027 0.127119 32.765499 9.015800 20.132803 17.435200 659.251062 11855.0 659.739467 659.553594 1507.0 5486.029785 20.725431
250 51.669998 0.120787 36.797798 12.827700 20.669636 14.760000 663.553877 1068.0 656.782451 659.998883 129.0 5486.382324 0.187266
256 7.000000 0.193277 33.599998 4.000000 17.000000 18.326000 660.448138 119.0 656.326883 658.311479 23.0 5990.794434 28.571430
258 27.600000 0.109890 54.761902 13.919400 19.782608 10.098000 646.835857 546.0 645.633578 646.294001 60.0 4777.769043 13.736264
263 115.300003 0.129892 15.655900 2.881700 20.164787 21.110500 667.167368 2325.0 666.148685 666.576950 302.0 4890.986816 22.709679
281 46.650002 0.127737 25.605499 8.073800 17.620579 14.163000 667.893071 822.0 661.581156 664.708583 105.0 5784.625000 0.851582
284 27.350000 0.112621 26.213600 16.116501 18.829981 14.209000 666.968963 515.0 659.846349 663.479007 58.0 5190.310059 0.000000
287 46.000000 0.154313 55.285500 13.851800 17.891304 10.551000 652.062665 823.0 648.806330 650.402278 127.0 5331.920898 2.187120
288 114.300003 0.115643 13.369200 2.916100 19.518810 16.955999 669.367531 2231.0 665.806105 667.439416 258.0 5599.607910 13.626176
290 15.500000 0.275081 32.686100 4.207100 19.935484 16.271999 665.115102 309.0 662.837178 664.027088 85.0 4700.432129 1.941748
292 47.959999 0.118012 47.305401 20.958099 20.141785 11.834000 656.883658 966.0 650.485988 653.693831 114.0 5318.240723 0.621118
294 35.500000 0.070866 30.373199 10.810800 21.464788 12.431000 663.358136 762.0 657.475610 660.403959 54.0 5143.190430 0.000000
296 489.299988 0.112995 30.643700 3.335000 20.130800 20.875750 661.465628 9850.0 663.110579 662.196270 1113.0 5081.455566 18.558376
297 5.000000 0.077519 50.387600 9.302300 25.799999 10.639000 648.618780 129.0 648.549217 648.603046 10.0 4016.416260 6.201550
298 565.510010 0.108296 28.376600 2.538400 18.777740 25.029619 666.036055 10619.0 667.101409 666.739172 1150.0 5374.050293 18.457481
306 33.700001 0.128889 31.851900 0.000000 20.029673 21.957001 663.041307 675.0 662.365592 662.682382 87.0 4543.305176 14.962962
316 10.880000 0.237500 53.503201 13.375800 14.705882 11.826000 657.115402 160.0 650.725460 653.880704 38.0 6870.346191 2.500000
318 108.750000 0.135123 16.065100 4.049300 20.891954 17.332001 671.475950 2272.0 667.167409 669.243356 307.0 5182.365723 1.320423
323 245.179993 0.137522 23.642000 7.037000 18.892242 20.469999 667.193966 4632.0 664.397837 665.754992 637.0 5666.713867 17.854059
328 5.500000 0.084746 51.666698 8.333300 21.454546 8.934000 648.312381 118.0 647.954079 648.068150 10.0 4451.256836 7.627119
330 232.750000 0.123785 19.372499 2.607800 20.339420 18.827230 666.844189 4734.0 666.145109 666.526671 586.0 4922.939941 10.857626
332 286.920013 0.089182 30.085100 5.613400 21.103443 23.667376 663.895376 6055.0 663.783643 663.791609 540.0 4829.717285 14.236169
334 139.300003 0.219564 23.098900 5.033900 20.107681 20.546000 671.557943 2801.0 668.080271 669.752223 615.0 5205.941895 2.070689
338 65.900002 0.095710 46.122101 22.194700 18.391502 10.551000 654.703879 1212.0 650.066687 652.415548 116.0 5090.587891 3.960396
341 288.630005 0.148634 15.512800 2.163500 21.678272 21.095751 669.614810 6257.0 669.478912 669.555534 930.0 4889.479980 8.726227
342 45.000000 0.140553 19.009199 4.723500 19.288889 15.167000 669.900481 868.0 665.261388 667.561157 122.0 5118.130371 0.000000
349 70.500000 0.126033 10.414200 2.426000 20.595745 17.656000 673.992356 1452.0 669.820784 672.024306 183.0 5028.561523 0.000000
350 8.000000 0.206452 12.258100 0.645200 19.375000 17.709000 674.467598 155.0 668.955346 671.693862 32.0 5695.466797 4.516129
352 30.080000 0.181658 29.021000 4.895100 18.849733 18.593000 668.890252 567.0 663.800812 666.315450 103.0 5156.437012 0.000000
371 25.500000 0.111359 48.484798 12.727300 17.607843 12.934000 658.155302 449.0 651.847919 654.937895 50.0 5884.432129 0.000000
375 5.100000 0.172840 0.000000 3.614500 15.882353 22.528999 687.413953 81.0 677.710248 682.690134 14.0 7667.571777 0.000000
377 108.959999 0.113776 3.730200 1.941700 17.988253 24.603001 681.105457 1960.0 677.414093 679.339899 223.0 5368.502930 1.479592
379 49.169998 0.327696 6.535300 2.178400 19.239374 33.455002 687.972117 946.0 685.793505 686.758283 310.0 5387.458984 1.479915
382 46.000000 0.228571 10.651500 4.239900 20.543478 19.346500 675.490470 945.0 671.170728 673.307374 216.0 5088.708008 2.539683
384 10.510000 0.213415 12.941200 2.352900 15.604186 22.863335 676.657773 164.0 670.484640 673.555717 35.0 6588.430664 11.585366
388 6.000000 0.201493 3.731300 0.746300 22.333334 21.957001 679.589037 134.0 675.039429 677.465932 27.0 5094.958984 2.238806
393 28.020000 0.062738 11.406800 3.992400 18.772305 13.567000 671.865756 526.0 666.245290 669.097893 33.0 5644.286133 0.000000
400 4.850000 0.259259 20.481899 13.253000 16.701031 30.840000 683.225780 81.0 673.540505 678.458755 21.0 7614.379395 12.345679
403 146.300003 0.158930 9.677400 0.826100 17.375256 27.947750 681.855225 2542.0 676.931333 679.498758 404.0 6604.063965 10.385523
405 65.120003 0.256846 1.416400 0.849900 16.262285 55.327999 705.686334 1059.0 708.851129 707.182838 272.0 6460.657227 2.266289
413 12.330000 0.100000 0.000000 0.454500 17.842661 43.230000 697.051592 220.0 693.210634 695.000096 22.0 6500.449707 1.363636
420 93.400002 0.176041 47.571201 6.923500 19.036402 12.502000 657.326803 1778.0 652.903342 655.114523 313.0 5993.392578 5.005624

We see that the predict method completes the provided data frame by filling in the missing columns with estimates. Let us check the predictions with a scatter plot of the true test-data values vs. the predictions.

[9]:
pl.figure(figsize=(10,10))
for i, col in [(1, "mathscr"), (2, "readscr"), (3, "testscr")]:
    ax = pl.subplot(2,2,i)
    ax.scatter(data_test[col], test_data_predictions_1[col])
    minval, maxval = data_test[col].min(), data_test[col].max()
    ax.plot([minval, maxval], [minval, maxval], ls="--", c="k")
    ax.set_xlabel("real value")
    ax.set_ylabel("prediction")
    ax.set_title(col)
(Figure: prediction vs. real value on the test set for mathscr, readscr and testscr, simple structure; the dashed line marks perfect agreement.)

The predictions follow the real test values reasonably well, scattering around the diagonal.

Apart from making predictions, we can also evaluate objectives. For example, we can measure the predictive performance with the Evaluator.

[10]:
performance_1 = causal_structure_1.evaluate_objective(Evaluator, data=data_test,
                                                      inputs=inputs, outputs=outputs,
                                                      metric="r2")
performance_1
[10]:
{'teachers': None,
 'compstu': None,
 'mealpct': None,
 'calwpct': None,
 'str': None,
 'avginc': None,
 'readscr': 0.8020755684153762,
 'enrltot': None,
 'mathscr': 0.6574450121854509,
 'testscr': 0.7577271442716229,
 'computer': None,
 'expnstu': None,
 'elpct': None}

The R2 score for the input columns is returned as None, since these columns are given rather than predicted. The values for the outputs are consistent with the visual impression of a decent fit.
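Assuming the Evaluator's `"r2"` metric is the usual coefficient of determination, R² = 1 − SS_res / SS_tot, the score can be sketched as follows. The `r2_score` function here is our own and the toy scores are made up. A perfect prediction scores 1, always predicting the mean scores 0, and anything worse than the mean gives a negative value, which is worth keeping in mind when reading scores close to zero.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - (residual sum of squares) / (total sum of squares)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

y = np.array([650.0, 660.0, 670.0, 680.0])     # made-up test scores
print(r2_score(y, y))                           # 1.0  (perfect prediction)
print(r2_score(y, np.full_like(y, y.mean())))   # 0.0  (always predicting the mean)
print(r2_score(y, y[::-1]))                     # -3.0 (worse than the mean)
```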

To judge whether the fitted model makes sense, we need to know what these columns mean.

[11]:
pydataset.data("Caschool", show_doc=True)
Caschool

PyDataset Documentation (adopted from R Documentation. The displayed examples are in R)

## The California Test Score Data Set

### Description

a cross-section from 1998-1999

_number of observations_ : 420

_observation_ : schools

_country_ : United States

### Usage

    data(Caschool)

### Format

A dataframe containing :

distcod: district code
county: county
district: district
grspan: grade span of district
enrltot: total enrollment
teachers: number of teachers
calwpct: percent qualifying for CalWorks
mealpct: percent qualifying for reduced-price lunch
computer: number of computers
testscr: average test score (read.scr+math.scr)/2
compstu: computer per student
expnstu: expenditure per student
str: student teacher ratio
avginc: district average income
elpct: percent of English learners
readscr: average reading score
mathscr: average math score

### Source

California Department of Education http://www.cde.ca.gov.

### References

Stock, James H. and Mark W. Watson (2003) _Introduction to Econometrics_,
Addison-Wesley Educational Publishers,
http://wps.aw.com/aw_stockwatsn_economtrcs_1, chapter 4–7.

### See Also

`Index.Source`, `Index.Economics`, `Index.Econometrics`, `Index.Observations`


In the fully connected model (all inputs influence all outputs), the biggest influences seem to be the total enrollment ‘enrltot’, the number of teachers ‘teachers’ and the percentage of students qualifying for subsidized (reduced-price) meals ‘mealpct’. The influence of the other scores (‘readscr’, ‘mathscr’) on ‘testscr’ is zero, since in this model all three scores are outputs.

We might now ask whether these quantities really are the main direct influences or whether they are only confounders. If we look at the correlations between the inputs, we see that they are substantially correlated, which can lead to confounding and the explaining-away effect.

[12]:
data_train[list(inputs)].corr().style.background_gradient(cmap='coolwarm', vmin=-1, vmax=1)
[12]:
teachers compstu calwpct mealpct str avginc enrltot computer expnstu elpct
teachers 1.000000 -0.205573 0.105457 0.142855 0.272616 0.023591 0.997424 0.936511 -0.082112 0.357818
compstu -0.205573 1.000000 -0.171575 -0.217232 -0.301937 0.183565 -0.213354 -0.041737 0.274414 -0.272306
calwpct 0.105457 -0.171575 1.000000 0.735591 0.017576 -0.511519 0.099840 0.072180 0.081346 0.350795
mealpct 0.142855 -0.217232 0.735591 1.000000 0.155836 -0.708045 0.147787 0.076777 -0.031435 0.679559
str 0.272616 -0.301937 0.017576 0.155836 1.000000 -0.222932 0.305878 0.243671 -0.590435 0.226483
avginc 0.023591 0.183565 -0.511519 -0.708045 -0.222932 1.000000 0.010920 0.076274 0.298755 -0.362238
enrltot 0.997424 -0.213354 0.099840 0.147787 0.305878 0.010920 1.000000 0.929746 -0.100289 0.365512
computer 0.936511 -0.041737 0.072180 0.076777 0.243671 0.076274 0.929746 1.000000 -0.061414 0.292201
expnstu -0.082112 0.274414 0.081346 -0.031435 -0.590435 0.298755 -0.100289 -0.061414 1.000000 -0.049224
elpct 0.357818 -0.272306 0.350795 0.679559 0.226483 -0.362238 0.365512 0.292201 -0.049224 1.000000

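We see, for example, that the total enrollment ‘enrltot’ and the number of teachers ‘teachers’ are very strongly correlated (0.997), and both are also strongly correlated with the number of computers ‘computer’. Next we look at the correlations between the inputs and the outputs.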

[13]:
# numeric_only=True skips the text columns (county, district, grspan)
data_train.corr(numeric_only=True).loc[list(inputs), list(outputs)].T.style.background_gradient(cmap='coolwarm', vmin=-1, vmax=1)
[13]:
teachers compstu calwpct mealpct str avginc enrltot computer expnstu elpct
testscr -0.157794 0.278550 -0.628098 -0.877528 -0.227190 0.734062 -0.167217 -0.084573 0.150482 -0.662471
readscr -0.189992 0.285396 -0.615917 -0.887392 -0.252007 0.720649 -0.199871 -0.117787 0.178198 -0.704534
mathscr -0.116787 0.260359 -0.617021 -0.832906 -0.191503 0.720223 -0.125351 -0.045295 0.114616 -0.591257
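To make the notion of confounding mentioned above concrete, here is a small simulation that does not use the school data. A hidden common cause `c` drives both an input `x` and the target `y`, while `x` has no direct effect on `y`. All coefficients are invented for illustration. Regressing `y` on `x` alone attributes a slope of about cov(x, y) / var(x) = 2 / 1.25 = 1.6 to `x`; once `c` is included, the slope of `x` drops to roughly zero.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
c = rng.normal(size=n)                   # hidden common cause (confounder)
x = c + 0.5 * rng.normal(size=n)         # input driven by c, no direct effect on y
y = 2.0 * c + 0.5 * rng.normal(size=n)   # target driven by c only

def ols(features, target):
    """Least-squares slopes of target on the feature columns (intercept fitted, then dropped)."""
    design = np.column_stack([np.ones(len(target)), features])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef[1:]

slope_x_only = ols(x, y)[0]                          # x alone picks up c's effect
slope_x, slope_c = ols(np.column_stack([x, c]), y)   # adjusting for c removes it
print(round(slope_x_only, 2), round(slope_x, 2), round(slope_c, 2))
```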

Advanced structure

With the information we have about the data set we can now try to make a more realistic model of the involved columns.

We start by formulating a number of assumptions, each of which translates into dependencies.

[14]:
dependencies = []

First of all, the data documentation tells us that ‘testscr’ is simply the average of ‘readscr’ and ‘mathscr’.

[15]:
dependencies += [
    [["mathscr", "readscr"], "testscr"]
]
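The documented relation can be checked on the first rows of the table at the top. The check also hints at why encoding it as a dependency can help: the simple structure predicts ‘testscr’ as a separate output, so its three predicted scores are not tied together. The values below are copied from the outputs above; the variable names are ours.

```python
import numpy as np

# First five rows of the data set shown at the top: readscr, mathscr, testscr.
read_scr = np.array([691.599976, 660.500000, 636.299988, 651.900024, 641.799988])
math_scr = np.array([690.000000, 661.900024, 650.900024, 643.500000, 639.900024])
test_scr = np.array([690.799988, 661.200012, 643.599976, 647.700012, 640.849976])

# testscr equals (readscr + mathscr) / 2 up to float32 rounding in the stored data.
max_deviation = np.abs((read_scr + math_scr) / 2 - test_scr).max()

# The simple model's predictions need not satisfy the relation exactly,
# e.g. test-set row 7 of test_data_predictions_1 (readscr, mathscr, testscr):
gap_row_7 = (617.612475 + 625.262781) / 2 - 621.376790
print(max_deviation, gap_row_7)
```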

Assumption 1: We believe the student-teacher ratio ‘str’ is what actually matters for test performance. The student-teacher ratio in turn depends on the number of students and the number of teachers.

[16]:
dependencies += [
    ["str", ["mathscr", "readscr"]],
    [["teachers", "enrltot"], "str"]
]

Assumption 2: We believe the number of teachers depends on the number of students and the expenditure per student.

[17]:
dependencies += [
    [["expnstu", "enrltot"], "teachers"]
]

Assumption 3: The number of computers per student might indicate the learning opportunities at the school and therefore influence the test results. The number of computers per student is in turn determined by the number of students and the number of computers.

[18]:
dependencies += [
    ["compstu", ["mathscr", "readscr"]],
    [["computer", "enrltot"], "compstu"]
]
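Assumptions 1 and 3 make ‘str’ and ‘compstu’ children of the columns they are computed from. On the first five rows of the data set this relationship is plain arithmetic, as a quick check with values copied from the table at the top shows. The variable names are ours.

```python
import numpy as np

# First five rows of the data set shown at the top.
enrltot = np.array([195, 240, 1550, 243, 1335], dtype=float)
teachers = np.array([10.9, 11.15, 82.900002, 14.0, 71.5])
computer = np.array([67, 101, 169, 85, 171], dtype=float)
str_ratio = np.array([17.889910, 21.524664, 18.697226, 17.357143, 18.671329])
compstu = np.array([0.343590, 0.420833, 0.109032, 0.349794, 0.128090])

# Both derived columns are ratios of their parents in the dependency list.
str_error = np.abs(enrltot / teachers - str_ratio).max()
compstu_error = np.abs(computer / enrltot - compstu).max()
print(str_error, compstu_error)  # tiny: only the rounding of the printed values remains
```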

Assumption 4: The number of computers probably depends on the number of students and the expenditure per student.

[19]:
dependencies += [
    [["expnstu", "enrltot"], "computer"]
]

Assumption 5: We assume that students from poorer families achieve worse test results on average. The data contain the percentage qualifying for reduced-price lunch ‘mealpct’, the percentage qualifying for CalWorks ‘calwpct’ and the average income in the school district ‘avginc’. Since the latter is a more indirect measure, we assume that ‘mealpct’ and ‘calwpct’ directly influence the test results, while both of them are influenced by ‘avginc’.

[20]:
dependencies += [
    [["mealpct", "calwpct"], ["mathscr", "readscr"]],
    ["avginc", ["mealpct", "calwpct"]]
]

Assumption 6: We assume that the share of non-native speakers influences the test results directly and also affects the economic situation of the students, i.e. ‘mealpct’ and ‘calwpct’. The percentage of English learners is ‘elpct’.

[21]:
dependencies += [
    ["elpct", ["mathscr", "readscr"]],
    ["elpct", ["mealpct", "calwpct"]]
]
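Judging from how it is used here, each entry of the dependency list has the form [parents, children], where either side is a single column name or a list of names. Spelling the list out as individual parent → child edges makes the assumed graph explicit, and a topological sort (Kahn's algorithm) confirms that it contains no cycles, as a directed causal graph should. The helpers `expand_edges` and `topological_order` are hypothetical illustrations, not part of halerium.

```python
from collections import defaultdict, deque

# Copy of the dependency list assembled above.
dependencies = [
    [["mathscr", "readscr"], "testscr"],
    ["str", ["mathscr", "readscr"]],
    [["teachers", "enrltot"], "str"],
    [["expnstu", "enrltot"], "teachers"],
    ["compstu", ["mathscr", "readscr"]],
    [["computer", "enrltot"], "compstu"],
    [["expnstu", "enrltot"], "computer"],
    [["mealpct", "calwpct"], ["mathscr", "readscr"]],
    ["avginc", ["mealpct", "calwpct"]],
    ["elpct", ["mathscr", "readscr"]],
    ["elpct", ["mealpct", "calwpct"]],
]

def as_list(names):
    """Accept a single column name or a list of names."""
    return [names] if isinstance(names, str) else list(names)

def expand_edges(deps):
    """Turn [parents, children] pairs into a set of (parent, child) edges."""
    return {(p, c) for parents, children in deps
            for p in as_list(parents) for c in as_list(children)}

def topological_order(edges):
    """Kahn's algorithm; raises ValueError if the edges contain a cycle."""
    nodes = {node for edge in edges for node in edge}
    indegree = {node: 0 for node in nodes}
    children = defaultdict(list)
    for parent, child in edges:
        children[parent].append(child)
        indegree[child] += 1
    queue = deque(sorted(node for node in nodes if indegree[node] == 0))
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in children[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    if len(order) != len(nodes):
        raise ValueError("dependencies contain a cycle")
    return order

edges = expand_edges(dependencies)
order = topological_order(edges)
roots = sorted(set(order) - {child for _, child in edges})  # columns without parents
print(len(edges), roots)
```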

Now we put these assumptions into a causal structure.

[22]:
causal_structure_2 = CausalStructure(dependencies)

Then we train it.

[23]:
causal_structure_2.train(data_train, method="MGVI")

Let us have a look at the performance on the test data.

[24]:
test_data_inputs = data_test[list(inputs)]  # pandas needs a list, not a set, as column indexer
test_data_predictions_2 = causal_structure_2.predict(test_data_inputs)

First visually…

[25]:
pl.figure(figsize=(10,10))
for i, col in [(1, "mathscr"), (2, "readscr"), (3, "testscr")]:
    ax = pl.subplot(2,2,i)
    ax.scatter(data_test[col], test_data_predictions_2[col])
    minval, maxval = data_test[col].min(), data_test[col].max()
    ax.plot([minval, maxval], [minval, maxval], ls="--", c="k")
    ax.set_xlabel("real value")
    ax.set_ylabel("prediction")
    ax.set_title(col)
(Figure: prediction vs. real value on the test set for mathscr, readscr and testscr, structured model; the dashed line marks perfect agreement.)

and by the R2 score.

[26]:
performance_2 = causal_structure_2.evaluate_objective(Evaluator, data=data_test,
                                                      inputs=inputs, outputs=outputs,
                                                      metric="r2")
[27]:
display("performance of simple model:")
display({
    key: value for key, value in performance_1.items() if key in outputs
})
display("performance of structured model:")
display({
    key: value for key, value in performance_2.items() if key in outputs
})
'performance of simple model:'
{'readscr': 0.8020755684153762,
 'mathscr': 0.6574450121854509,
 'testscr': 0.7577271442716229}
'performance of structured model:'
{'readscr': 0.7736833098673347,
 'testscr': 0.733377201980135,
 'mathscr': 0.6371819769563171}

We can see that the R2 scores are actually slightly worse than for the simple causal structure, which only had inputs and outputs. Maybe some of our assumptions were incomplete or even wrong.
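To put “slightly worse” into numbers, the difference can be computed directly from the two outputs above. The dictionaries are matched by key, so the different key order in the two printouts does not matter.

```python
# R2 scores copied from the output above.
simple = {'readscr': 0.8020755684153762, 'mathscr': 0.6574450121854509,
          'testscr': 0.7577271442716229}
structured = {'readscr': 0.7736833098673347, 'testscr': 0.733377201980135,
              'mathscr': 0.6371819769563171}

# Structured minus simple, per output column.
difference = {key: structured[key] - simple[key] for key in sorted(simple)}
for key, value in difference.items():
    print(f"{key}: {value:+.3f}")
```

All three differences are negative but smaller than 0.03 in absolute value.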

So what have we gained if the prediction performance is not better? We can see that once we start comparing predictions on partial data.

Comparing the performance of the simple and the advanced structure

As we have seen, when the outputs are predicted from all inputs, both models perform about equally well. When all input data are available and the training and test sets are split randomly, it does not matter much whether the learned behavior is based on real causal relationships or on confounders.

However, if we make predictions from incomplete data, this changes.
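One plausible way to understand this, sketched with a toy linear-Gaussian model rather than with halerium itself: in the simple structure the inputs have no links to one another, so a missing input can only be filled in from its overall distribution. In the structured model, ‘calwpct’ and ‘mealpct’ share the parents ‘avginc’ and ‘elpct’, so observing one of them carries information about the other. The sketch compares the two ways of filling in a missing input. The correlation 0.74 is roughly the training correlation of ‘calwpct’ and ‘mealpct’ from the table above; the effect sizes `a`, `b` are invented.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """Coefficient of determination."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

rng = np.random.default_rng(42)
n = 10_000
rho = 0.74        # roughly corr(calwpct, mealpct) in the training data
a, b = 0.2, 0.9   # invented standardized effects of the two inputs on the score

# Standardized, correlated inputs and a score that depends on both.
x1 = rng.normal(size=n)                                    # observed input ("calwpct")
x2 = rho * x1 + np.sqrt(1 - rho**2) * rng.normal(size=n)   # missing input ("mealpct")
y = -(a * x1 + b * x2) + 0.3 * rng.normal(size=n)

# Only x1 is observed, so x2 is replaced by its expected value given x1.
pred_independent = -a * x1                   # inputs treated as independent: E[x2 | x1] = 0
pred_correlated = -(a * x1 + b * rho * x1)   # correlation known: E[x2 | x1] = rho * x1

r2_independent = r2_score(y, pred_independent)
r2_correlated = r2_score(y, pred_correlated)
print(r2_independent, r2_correlated)  # roughly 0.25 vs. 0.62
```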

Example 1: predicting based on subsidized meals and CalWorks

Let’s try a prediction based only on the percentage of students receiving subsidized meals and the percentage in the CalWorks program.

[28]:
display("prediction based on simple model:")
display({
    key: value for key, value in
    causal_structure_1.evaluate_objective(Evaluator, data=data_test,
                                          inputs=['calwpct', 'mealpct'], outputs=outputs,
                                          metric="r2").items()
    if key in outputs
})
display("prediction based on structured model:")
display({
    key: value for key, value in
    causal_structure_2.evaluate_objective(Evaluator, data=data_test,
                                          inputs=['calwpct', 'mealpct'], outputs=outputs,
                                          metric="r2").items()
    if key in outputs
})
'prediction based on simple model:'
{'readscr': 0.6342623417400537,
 'mathscr': 0.5509301033970446,
 'testscr': 0.590015453901932}
'prediction based on structured model:'
{'readscr': 0.7078249997685941,
 'testscr': 0.6881614958989846,
 'mathscr': 0.6110603334844447}

We see that in this scenario the structured model outperforms the simple model, but both models still achieve decent scores.
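The comparisons in this section repeat the same cell with different inputs. A small hypothetical helper, relying only on the `evaluate_objective` call pattern already used in this notebook, can shorten them. The usage lines are left as comments because they need the trained models from above.

```python
def output_scores(model, evaluator, data, inputs, outputs, metric="r2"):
    """Hypothetical helper: evaluate `model` on `data` and keep only the output scores."""
    scores = model.evaluate_objective(evaluator, data=data, inputs=inputs,
                                      outputs=outputs, metric=metric)
    return {key: value for key, value in scores.items() if key in outputs}

# Usage in this notebook:
# for name, model in [("simple", causal_structure_1), ("structured", causal_structure_2)]:
#     display(f"prediction based on {name} model:")
#     display(output_scores(model, Evaluator, data_test, ['calwpct'], outputs))
```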

Example 2: predicting based on CalWorks only

Let’s try a prediction based only on the percentage in the CalWorks program.

[29]:
display("prediction based on simple model:")
display({
    key: value for key, value in
    causal_structure_1.evaluate_objective(Evaluator, data=data_test,
                                          inputs=['calwpct'], outputs=outputs,
                                          metric="r2").items()
    if key in outputs
})
display("prediction based on structured model:")
display({
    key: value for key, value in
    causal_structure_2.evaluate_objective(Evaluator, data=data_test,
                                          inputs=['calwpct'], outputs=outputs,
                                          metric="r2").items()
    if key in outputs
})
'prediction based on simple model:'
{'readscr': -0.033142940725477965,
 'mathscr': 0.11518897669104289,
 'testscr': 0.029182778885538996}
'prediction based on structured model:'
{'readscr': 0.32351923326782916,
 'testscr': 0.3688571477752969,
 'mathscr': 0.37847951309032846}

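In this scenario the structured model clearly outperforms the simple model, whose R2 scores are close to zero (negative for ‘readscr’), i.e. hardly better than always predicting the mean.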

Example 3: predicting based on district average income

Let’s try a prediction based only on the district average income.

[30]:
display("prediction based on simple model:")
display({
    key: value for key, value in
    causal_structure_1.evaluate_objective(Evaluator, data=data_test,
                                          inputs=['avginc'], outputs=outputs,
                                          metric="r2").items()
    if key in outputs
})
display("prediction based on structured model:")
display({
    key: value for key, value in
    causal_structure_2.evaluate_objective(Evaluator, data=data_test,
                                          inputs=['avginc'], outputs=outputs,
                                          metric="r2").items()
    if key in outputs
})
'prediction based on simple model:'
{'readscr': 0.2816587802736583,
 'mathscr': 0.2002535289464773,
 'testscr': 0.2320481281685418}
'prediction based on structured model:'
{'readscr': 0.34039792627046295,
 'testscr': 0.3543316153774535,
 'mathscr': 0.3358698650229551}

In this scenario, too, the structured model clearly outperforms the simple model.

To summarize, a causal structure based on realistic dependency assumptions will not necessarily outperform a black-box model with only inputs and outputs. In these examples, however, it is more robust when making predictions from incomplete data.