The Trainer class

Aliases

halerium.core.Trainer
halerium.core.model.Trainer
class Trainer(graph, data=None, method='MAPFisher', compiler=None, return_solved=True, model_args=None, solver_args=None, trained_graph_name=None, n_samples=100)

Class for training a model, given as a Graph, on data.

Parameters:
  • graph (halerium.core.Graph) – A Graph instance representing the model to train.

  • data (dict, halerium.core.DataLinker, optional) – The data for the training. Either a dictionary with variables as keys and data arrays as values, or a DataLinker holding links to the variables in graph. The default is None.

  • method ({'MAPFisher', 'MAP', 'ADVI', 'MGVI', or 'MCMC'}, optional) – The solving method. According to the chosen method either a MAPFisherModel, MAPModel, ADVIModel, MGVIModel, or MCMCModel instance is created. See the corresponding model class for further information. The default is ‘MAPFisher’.

  • compiler (halerium.core.compiler.compiler_base.CompilerBase, optional) – Which compiler to use to create the numerical arrays of the model. If None, a TFCompiler instance is created. The default is None.

  • return_solved (bool, optional) – Whether to train the model during initialization, so that the Trainer is returned already solved. The default is True.

  • model_args (dict, optional) – Model arguments that depend on the specified method. See the corresponding model classes for further information. The default is None.

  • solver_args (dict, optional) – The arguments to pass to the model’s solver function. The default is None.

  • trained_graph_name (str, optional) – The name to give to the trained graph returned by calling the instance. The default is None.

  • n_samples (int, optional) – The number of samples from which to estimate the posterior distributions. The default is 100.

Raises:

ValueError – If the supplied method is unknown.
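The method argument selects which model class the trainer builds, and an unrecognized name raises ValueError. That lookup can be pictured as a small dispatch table. This is an illustrative sketch, not Halerium's code: the helper `select_model_class` is hypothetical, it returns class names as strings rather than the real model classes, and whether the real lookup is case-sensitive is an assumption.

```python
# Method names from the parameter list, mapped to the model class each selects.
MODEL_CLASSES = {
    "MAPFisher": "MAPFisherModel",
    "MAP": "MAPModel",
    "ADVI": "ADVIModel",
    "MGVI": "MGVIModel",
    "MCMC": "MCMCModel",
}


def select_model_class(method="MAPFisher"):
    """Hypothetical helper: map a method name to its model class name."""
    try:
        return MODEL_CLASSES[method]
    except KeyError:
        # Mirrors the documented behaviour: unknown methods raise ValueError.
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(MODEL_CLASSES)}"
        ) from None


print(select_model_class())        # MAPFisherModel
print(select_model_class("MCMC"))  # MCMCModel
```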

__call__(trained_graph_name=None, n_samples=None)

Create a posterior graph from the trained model.

Parameters:
  • trained_graph_name (str, optional) – The name to give to the posterior graph returned.

  • n_samples (int, optional) – The number of samples to estimate the posterior distributions from.

Returns:

trained_graph – The posterior graph representing the trained model.

Return type:

halerium.core.Graph
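Because the posterior graph is estimated from n_samples samples, its distributions presumably carry Monte Carlo error. The generic effect can be seen with standard-library sampling from a known Gaussian: the standard error of a sample mean shrinks like sqrt(variance / n). This sketch does not call Halerium; `estimate_from_samples`, the target distribution N(1, 0.5), and the seed are arbitrary illustrative choices.

```python
import math
import random
import statistics


def estimate_from_samples(n_samples, mean=1.0, variance=0.5, seed=0):
    """Estimate the mean and variance of N(mean, variance) from n_samples draws."""
    rng = random.Random(seed)
    draws = [rng.gauss(mean, math.sqrt(variance)) for _ in range(n_samples)]
    return statistics.mean(draws), statistics.variance(draws)


for n in (100, 1000, 10000):
    est_mean, est_var = estimate_from_samples(n)
    # Expected standard error of the mean estimate for this n.
    std_err = math.sqrt(0.5 / n)
    print(n, round(est_mean, 3), round(est_var, 3), round(std_err, 4))
```

Larger n_samples therefore gives a more faithful estimate of the posterior at the cost of more sampling work.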

Examples

>>> from halerium.core import Graph, Variable, StaticVariable
>>> from halerium.core.model.train import Trainer
>>>
>>> g = Graph("g")
>>> with g:
...     x = Variable("x", mean=0, variance=1)
...     s = StaticVariable("s", mean=0, variance=1)
...     y = Variable("y", mean=x * s, variance=1)
...
>>> trainer = Trainer(g, data={g.x: [0, 1], g.y: [0, 2]})
>>>
>>> trained_graph = trainer()
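For the example above, the posterior of s can be worked out by hand, which gives a reference value to compare the trained graph against. Assuming, as the mean and variance arguments suggest, that all variables are Gaussian, and treating the observed x as fixed, s has prior N(0, 1) and each y_i follows N(x_i * s, 1). This conjugate model has an exact Gaussian posterior. The helper `conjugate_posterior` is hypothetical and uses no Halerium API.

```python
def conjugate_posterior(xs, ys, prior_mean=0.0, prior_var=1.0, noise_var=1.0):
    """Exact posterior of s for s ~ N(prior_mean, prior_var), y_i ~ N(x_i * s, noise_var)."""
    # Posterior precision = prior precision + sum of x_i**2 / noise variance.
    precision = 1.0 / prior_var + sum(x * x for x in xs) / noise_var
    # Posterior mean = precision-weighted combination of prior and data.
    mean = (prior_mean / prior_var + sum(x * y for x, y in zip(xs, ys)) / noise_var) / precision
    return mean, 1.0 / precision


mean, var = conjugate_posterior([0, 1], [0, 2])
print(mean, var)  # 1.0 0.5
```

With the example data this gives a posterior mean of 1.0 and variance 0.5 for s. Since that posterior is exactly Gaussian, one would expect the default 'MAPFisher' method, which by its name pairs a MAP estimate with a Fisher-information spread, to report values close to these, up to optimizer tolerance and the sampling error controlled by n_samples.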

fit(**kwargs)

Fits the model to the data (alias for solve).

Call this method if the model has not been trained yet, or if further training is desired.

Note that the model is already fit to the data during initialization unless return_solved was set to False. Additional calls to fit may then improve the fit to that data. The method cannot be used to fit the model to new data.

Parameters:

kwargs – Any keyword arguments to pass to the minimizer.

Return type:

None.
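The life cycle described for fit (one fit during initialization unless return_solved is False, keyword arguments forwarded to the solver, and repeated calls continuing from the current estimate on the same data) can be mimicked with a toy class. This is a hypothetical stand-in, not Halerium's implementation: `ToyTrainer`, `solve`, `n_steps`, and `learning_rate` are invented for illustration, and plain gradient descent replaces whatever minimizer the real model uses. The objective is the negative log posterior of s for the model in the example above.

```python
class ToyTrainer:
    """Hypothetical stand-in mimicking the documented Trainer life cycle."""

    def __init__(self, xs, ys, return_solved=True):
        # The data are fixed at construction; later fits reuse them.
        self.xs = list(xs)
        self.ys = list(ys)
        self.s = 0.0  # initial estimate of the static parameter s
        if return_solved:
            self.solve()

    def _gradient(self):
        # Gradient of 0.5*s**2 + 0.5*sum((y - x*s)**2), the negative log
        # posterior of s ~ N(0, 1), y ~ N(x*s, 1), up to a constant.
        return self.s - sum(x * (y - x * self.s) for x, y in zip(self.xs, self.ys))

    def solve(self, n_steps=5, learning_rate=0.1):
        # Continue gradient descent from the current estimate.
        for _ in range(n_steps):
            self.s -= learning_rate * self._gradient()

    def fit(self, **kwargs):
        # Alias for solve: keyword arguments are forwarded unchanged.
        self.solve(**kwargs)


trainer = ToyTrainer([0, 1], [0, 2])  # already fit once during initialization
first = trainer.s
trainer.fit(n_steps=5)                # a further call refines the same fit
second = trainer.s
print(abs(1 - first) > abs(1 - second))  # True: closer to the MAP value 1
```

Constructing the toy with return_solved=False leaves s at its initial value until fit is called, which matches the documented reason for calling fit on an untrained model.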