The InterventionPredictor class
Aliases
halerium.InterventionPredictor
halerium.core.InterventionPredictor
halerium.core.objectives.InterventionPredictor
- class InterventionPredictor(graph, data=None, interventions=None, method='MAPFisher', compiler=None, model_args=None, solver_args=None, measure='mean', n_samples=100, name='InterventionPredictor', description=None, copy_graph=True)
Class for computing predictions for interventions.
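To illustrate what a prediction under an intervention means, here is a minimal sketch in plain NumPy. It uses ancestral sampling on the same toy model as the example below (x ~ N(0, 1), y ~ N(x + 1, 1)) and assumes the usual do-operator semantics: an intervened variable is forced to the given value and ignores its own distribution, while its descendants see the forced value. The helper `sample_toy_model` and its `do` argument are hypothetical. This is not how halerium computes predictions; its default method is 'MAPFisher'.

```python
import numpy as np


def sample_toy_model(n_samples, do=None, seed=0):
    """Ancestral sampling of the toy model x ~ N(0, 1), y ~ N(x + 1, 1).

    Hypothetical helper (assumption: do-operator semantics). ``do`` maps "x"
    and/or "y" to forced values. A forced variable ignores its own
    distribution, and its children see the forced value.
    """
    do = {} if do is None else do
    rng = np.random.default_rng(seed)

    def forced(name):
        # Repeat the forced value once per sample: shape (n_samples, *value.shape).
        value = np.asarray(do[name], dtype=float)
        return np.broadcast_to(value, (n_samples,) + value.shape)

    x = forced("x") if "x" in do else rng.normal(0.0, 1.0, size=n_samples)
    y = forced("y") if "y" in do else rng.normal(x + 1.0, 1.0)
    return {"x": x, "y": y}


# Intervening on x propagates to its child y ...
samples = sample_toy_model(10_000, do={"x": [0, 1, 2]})
print(samples["y"].mean(axis=0))  # close to [1, 2, 3]

# ... while intervening on y leaves its parent x untouched (mean stays near 0).
samples = sample_toy_model(10_000, do={"y": 5.0})
print(samples["x"].mean())
```

Under ordinary conditioning, observing y = 5 would instead shift the expected value of x to 2 in this model. An intervention does not propagate information back to the parents of the intervened variable.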
- Parameters:
graph (halerium.core.Graph) – The graph for which to compute intervention predictions.
data (dict, halerium.core.DataLinker, optional) – The input data for the intervention prediction. Either dictionary with variables as keys and data arrays as values, or a DataLinker holding links to the variables in graph. The default is None.
interventions (dict, halerium.core.DataLinker, optional) – The intervention data for the prediction. Either dictionary with variables as keys and data arrays as values, or a DataLinker holding links to the variables in graph. Note that values in interventions override values in data. The default is None.
method ({'Forward', 'MAPFisher', 'MAP', 'ADVI', or 'MGVI'}, optional) – The solving method. According to the chosen method, either a ForwardModel, MAPFisherModel, MAPModel, ADVIModel, or MGVIModel instance is created. See the corresponding model class for further information. The default is ‘MAPFisher’.
compiler (halerium.core.compiler.compiler_base.CompilerBase, optional) – Which compiler to use to create the numerical arrays of the model. If None, a TFCompiler instance is created. The default is None.
model_args (dict, optional) – Model arguments that depend on the specified method. See the corresponding model classes for further information. The default is None.
solver_args (dict, optional) – The arguments to pass to the model’s solver function. The default is None.
measure (str, callable, tuple, list, dict of str, dict of callable, optional) – The statistical property to predict. Recognized str values are ‘mean’, ‘variance’, and ‘standard_deviation’. A callable must take an ndarray as its first argument and accept an ‘axis’ keyword argument giving the axis along which the samples are ordered (e.g., ‘np.mean’). If a tuple, list, or dict, each of its elements (for a dict, each value) must be a tuple, list, dict, recognized str, or callable. The default is ‘mean’.
n_samples (int, optional) – The number of samples to estimate the predictions from. The default is 100.
name (str, optional) – The name of the objective. The default is ‘InterventionPredictor’.
description (str, optional) – The description of the objective. The default is None.
copy_graph (bool, optional) – Whether the objective should make a copy of the graph for its own use, or just keep the graph itself as attribute. Users should leave this set to the default True, unless they are certain that the graph alterations by the objective don’t have an adverse impact on other parts using the graph, and that the graph won’t be altered by the user or other code.
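The rules for `measure` can be made concrete with a small NumPy sketch. `apply_measure` is a hypothetical helper, not part of halerium. It maps the recognized strings to `np.mean`, `np.var`, and `np.std` (assuming population statistics, i.e. `ddof=0`), calls any callable with an `axis` keyword, and recurses into tuples, lists, and dicts so that the result has the same structure as the measure specification.

```python
import numpy as np

# Hypothetical mapping of the recognized measure names to NumPy reductions.
NAMED_MEASURES = {
    "mean": np.mean,
    "variance": np.var,            # assumption: ddof=0
    "standard_deviation": np.std,  # assumption: ddof=0
}


def apply_measure(measure, samples, axis=0):
    """Reduce `samples` along `axis` according to a (possibly nested) measure spec."""
    if isinstance(measure, str):
        if measure not in NAMED_MEASURES:
            raise ValueError(f"unrecognized measure: {measure!r}")
        return NAMED_MEASURES[measure](samples, axis=axis)
    if isinstance(measure, dict):
        # Keep the keys, apply each value recursively.
        return {key: apply_measure(m, samples, axis) for key, m in measure.items()}
    if isinstance(measure, (tuple, list)):
        # Preserve the container type (tuple stays tuple, list stays list).
        return type(measure)(apply_measure(m, samples, axis) for m in measure)
    if callable(measure):
        # Custom callables receive the samples first and `axis` as a keyword.
        return measure(samples, axis=axis)
    raise TypeError(f"unsupported measure type: {type(measure).__name__}")


# Four samples (ordered along axis 0) of a two-component variable.
samples = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
print(apply_measure("mean", samples))  # mean of each column
print(apply_measure(("mean", {"median": np.median}), samples))
```

Passing `np.median` works because, like `np.mean`, it accepts the array first and `axis` as a keyword, which is the contract described for callables.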
- Raises:
ValueError – If the supplied method is unknown.
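The mapping from `method` to model class, together with the `ValueError` for unknown methods, amounts to a lookup table. In this sketch, `METHOD_TO_MODEL` and `resolve_model_class` are hypothetical names, and exact, case-sensitive matching is an assumption. The class names are the ones listed for the `method` parameter, kept as strings so the sketch runs without halerium installed.

```python
# Hypothetical lookup table: method name -> name of the model class it selects.
METHOD_TO_MODEL = {
    "Forward": "ForwardModel",
    "MAPFisher": "MAPFisherModel",
    "MAP": "MAPModel",
    "ADVI": "ADVIModel",
    "MGVI": "MGVIModel",
}


def resolve_model_class(method="MAPFisher"):
    """Return the model class name for `method`, or raise ValueError if unknown."""
    try:
        return METHOD_TO_MODEL[method]
    except KeyError:
        known = ", ".join(sorted(METHOD_TO_MODEL))
        raise ValueError(f"Unknown method {method!r}; expected one of: {known}") from None


print(resolve_model_class())        # MAPFisherModel
print(resolve_model_class("ADVI"))  # ADVIModel
```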
- __call__(fetches=None, n_samples=None)
Predict values for given fetches.
- Parameters:
fetches (halerium.core.scope.Scopee, dict, list or tuple, optional) – The graph elements to predict values for. If no fetches are provided, the default is the graph itself and all its subgraphs, entities, and (static) variables. Note, however, that values are only computed for (static) variables of the graph; for all other graph elements, the prediction returns None.
n_samples (int, optional) – The number of samples to estimate the prediction from. The default is to use the number provided at initialization.
- Returns:
The predictions for the graph elements. Note that for elements that are not (Static)Variables, None will be returned.
- Return type:
prediction
Examples
>>> from halerium.core import Graph, Variable
>>> from halerium import Predictor
>>>
>>> g = Graph("g")
>>> with g:
...     x = Variable("x", mean=0, variance=1)
...     y = Variable("y", mean=x + 1, variance=1)
>>>
>>> predictor = Predictor(g, {g.x: [0, 1, 2]})
>>>
>>> predictor(g.y)
array([1., 2., 3.])
- dump_dict(value_postprocessor=None)
Dump a dict with information on the objective.
The returned dict contains the name, the description, and the values resulting from a call of the objective. It also contains additional keys that the GUI uses to display the results of the objective appropriately.
- Parameters:
value_postprocessor (optional) – A function to apply to the values returned by the call of the objective. The default is None, in which case no post-processing is done.
- Returns:
result – A dictionary containing the name, description, etc. of the objective.
- Return type:
dict
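One plausible use of `value_postprocessor` is making the dumped values JSON-serializable, since predictions come back as NumPy arrays (or None for non-variable elements), as in the example above. The sketch imitates the documented behavior with a stand-in function. `dump_objective_dict`, `to_json_friendly`, and the keys `"name"`, `"description"`, and `"values"` are hypothetical. The real dict also carries GUI-specific keys that are not modeled here.

```python
import json

import numpy as np


def dump_objective_dict(name, description, values, value_postprocessor=None):
    """Hypothetical stand-in: bundle an objective's name, description and values."""
    if value_postprocessor is not None:
        # Post-process the values before they are stored in the dict.
        values = value_postprocessor(values)
    return {"name": name, "description": description, "values": values}


def to_json_friendly(values):
    """Convert arrays to nested lists; keep None (non-variable elements) as-is."""
    return {key: None if v is None else np.asarray(v).tolist() for key, v in values.items()}


values = {"y": np.array([1.0, 2.0, 3.0]), "g": None}
result = dump_objective_dict("InterventionPredictor", None, values, to_json_friendly)
print(json.dumps(result))
```

Without the post-processor, `json.dumps` would fail on the NumPy array, which is why a conversion step like this is a natural fit for the hook.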