{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Encoding causal structure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Imagine having a process with some causal structure.\n",
    "\n",
    "The causal structure we assume simply encodes that\n",
    " - y1 depends on settings x1 and x2\n",
    " - y2 depends on settings x2 and x3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import halerium.core as hal\n",
    "from halerium.core.regression import connect_via_regression\n",
    "\n",
    "# NOTE: names such as `settings`, `inputs`, `outputs` and `substructure1` are\n",
    "# never assigned in this cell; the code relies on halerium making the objects\n",
    "# of the active `with` scope available under their names.\n",
    "g = hal.Graph(\"g\")\n",
    "with g:\n",
    "    # top-level settings x1, x2, x3, each declared with mean 0 and variance 1\n",
    "    hal.Entity(\"settings\")\n",
    "    with settings:\n",
    "        hal.Variable(\"x1\", mean=0., variance=1.)\n",
    "        hal.Variable(\"x2\", mean=0., variance=1.)\n",
    "        hal.Variable(\"x3\", mean=0., variance=1.)\n",
    "\n",
    "    # sub-graph 1: y1 is connected by a regression to x1 and x2 only\n",
    "    hal.Graph(\"substructure1\")\n",
    "    with substructure1:\n",
    "        with inputs:\n",
    "            hal.Entity(\"settings\")\n",
    "            with settings:\n",
    "                hal.Variable(\"x1\")\n",
    "                hal.Variable(\"x2\")\n",
    "\n",
    "        with outputs:\n",
    "            hal.Entity(\"results\")\n",
    "            with results:\n",
    "                hal.Variable(\"y1\", variance=0.1)\n",
    "\n",
    "        # later cells address this regression as g.substructure1.reg_y1\n",
    "        connect_via_regression(name_prefix=\"reg\",\n",
    "                               inputs=[inputs.settings.x1, inputs.settings.x2],\n",
    "                               outputs=[outputs.results.y1])\n",
    "\n",
    "    # sub-graph 2: y2 is connected by a regression to x2 and x3 only\n",
    "    hal.Graph(\"substructure2\")\n",
    "    with substructure2:\n",
    "        with inputs:\n",
    "            hal.Entity(\"settings\")\n",
    "            with settings:\n",
    "                hal.Variable(\"x2\")\n",
    "                hal.Variable(\"x3\")\n",
    "\n",
    "        with outputs:\n",
    "            hal.Entity(\"results\")\n",
    "            with results:\n",
    "                hal.Variable(\"y2\", variance=0.1)\n",
    "\n",
    "        # later cells address this regression as g.substructure2.reg_y2\n",
    "        connect_via_regression(name_prefix=\"reg\",\n",
    "                               inputs=[inputs.settings.x2, inputs.settings.x3],\n",
    "                               outputs=[outputs.results.y2])\n",
    "\n",
    "    # feed the top-level settings into the input settings of both sub-graphs\n",
    "    # (presumably variables are matched by name - confirm in the halerium docs)\n",
    "    hal.link(settings, substructure1.inputs.settings)\n",
    "    hal.link(settings, substructure2.inputs.settings)\n",
    "\n",
    "# use the hal.show function to display the graph in the online platform\n",
    "#hal.show(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
"source": [
    "### Let us generate artificial data to test this.\n",
    "\n",
    "The important thing is that the settings in the past were not chosen randomly,\n",
    "but they were managed in some manner that did not have a machine learning\n",
    "use case in mind."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### generate training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ground-truth regression parameters of the simulated process\n",
    "real_matrix1 = np.array([1., -1])\n",
    "real_intercept1 = 0.\n",
    "real_matrix2 = np.array([-1, 1.])\n",
    "real_intercept2 = 0.\n",
    "\n",
    "# the two outputs that are simulated here and predicted further below\n",
    "result_variables = [g.substructure1.outputs.results.y1,\n",
    "                    g.substructure2.outputs.results.y2]\n",
    "\n",
    "\n",
    "def settings_as_data(settings_values):\n",
    "    \"\"\"Map columns 0, 1, 2 of a 2-d settings array onto x1, x2, x3 of `g`.\"\"\"\n",
    "    return {g.settings.x1: settings_values[:, 0],\n",
    "            g.settings.x2: settings_values[:, 1],\n",
    "            g.settings.x3: settings_values[:, 2]}\n",
    "\n",
    "\n",
    "def simulate_results(settings_values):\n",
    "    \"\"\"Sample y1 and y2 for the given settings with the ground-truth parameters.\n",
    "\n",
    "    Returns an array with one row per settings row and the columns (y1, y2).\n",
    "    \"\"\"\n",
    "    data = settings_as_data(settings_values)\n",
    "    data.update({g.substructure1.reg_y1.location.slope: real_matrix1,\n",
    "                 g.substructure1.reg_y1.location.intercept: real_intercept1,\n",
    "                 g.substructure2.reg_y2.location.slope: real_matrix2,\n",
    "                 g.substructure2.reg_y2.location.intercept: real_intercept2})\n",
    "    samples = hal.get_generative_model(g, data=data).get_samples(result_variables)\n",
    "    # keep index 0 along the second axis (assumed to be the sample axis - TODO confirm)\n",
    "    return np.array(samples)[:, 0, :].T\n",
    "\n",
    "\n",
    "# note: x1 and x3 take the same value in every historical row,\n",
    "# so their individual effects cannot be separated from this data alone\n",
    "past_settings = np.array([[-1., -1., -1.]]*10 +\n",
    "                         [[0., -1., 0.]]*10 +\n",
    "                         [[1., 1., 1.]]*10)\n",
    "\n",
    "past_results = simulate_results(past_settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### generate test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# in the test data x1 and x3 no longer coincide\n",
    "future_settings = np.array([[-1., 0., 0.]]*10 +\n",
    "                           [[0.5, 0.5, -1.]]*10)\n",
    "\n",
    "future_results = simulate_results(future_settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### train a black box model and apply it on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# plain linear regression of both results on all three settings,\n",
    "# i.e. without any knowledge of the causal structure\n",
    "lr = LinearRegression()\n",
    "lr.fit(past_settings, past_results)\n",
    "black_box_prediction_past = lr.predict(past_settings)\n",
    "black_box_prediction_future = lr.predict(future_settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### train the causal model and apply it on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train on the historical settings together with the observed results\n",
    "training_data = settings_as_data(past_settings)\n",
    "training_data.update({g.substructure1.outputs.results.y1: past_results[:, 0],\n",
    "                      g.substructure2.outputs.results.y2: past_results[:, 1]})\n",
    "causal_model = hal.get_posterior_model(g, data=training_data)\n",
    "trained_causal_graph = causal_model.get_posterior_graph()\n",
    "\n",
    "\n",
    "def predict_with_causal_graph(settings_values):\n",
    "    \"\"\"Return the means of (y1, y2) under the trained graph, one row per settings row.\"\"\"\n",
    "    means = hal.get_generative_model(\n",
    "        trained_causal_graph,\n",
    "        data=settings_as_data(settings_values)).get_means(result_variables)\n",
    "    return np.array(means).T\n",
    "\n",
    "\n",
    "causal_prediction_past = predict_with_causal_graph(past_settings)\n",
    "causal_prediction_future = predict_with_causal_graph(future_settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compare the performance\n",
    "\n",
    "The score is one minus the mean squared error divided by the variance of all results,\n",
    "so 1 means a perfect prediction. In the training data x1 and x3 always take the same value,\n",
    "so the black box model cannot separate their effects, whereas the causal model only allows\n",
    "x1 to act on y1 and x3 to act on y2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "black box on training data: 0.8700155151258816\n",
      "causal model on training data: 0.8646083422454567\n",
      "black box on test data: 0.0830037653624115\n",
      "causal model on test data: 0.8269131649522296\n"
     ]
    }
   ],
   "source": [
    "# common normalization: variance over all past and future results\n",
    "norm = np.var(np.append(past_results, future_results))\n",
    "\n",
    "\n",
    "def relative_performance(prediction, truth):\n",
    "    \"\"\"Return 1 - mean squared error / norm for a prediction against the true results.\"\"\"\n",
    "    return 1-np.mean((prediction-truth)**2) / norm\n",
    "\n",
    "\n",
    "black_box_performance_past = relative_performance(black_box_prediction_past, past_results)\n",
    "causal_performance_past = relative_performance(causal_prediction_past, past_results)\n",
    "\n",
    "black_box_performance_future = relative_performance(black_box_prediction_future, future_results)\n",
    "causal_performance_future = relative_performance(causal_prediction_future, future_results)\n",
    "\n",
    "print(\"black box on training data:\", black_box_performance_past)\n",
    "print(\"causal model on training data:\", causal_performance_past)\n",
    "print(\"black box on test data:\", black_box_performance_future)\n",
    "print(\"causal model on test data:\", causal_performance_future)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       ""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from plots import plot_compare_causal_and_black_box\n",
    "from IPython.display import Image\n",
    "\n",
    "Image(plot_compare_causal_and_black_box(black_box_performance_past,\n",
    "                                        black_box_performance_future,\n",
    "                                        causal_performance_past,\n",
    "                                        causal_performance_future))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }