The MGVIModel class#

Aliases#

halerium.core.model.MGVIModel
class MGVIModel(graph, data=None, compiler=None, initial_values=None, initial_source_values=None, random_initial_value_scale=0.0001, sample_caching=True, sample_symmetry=True, fisher_diagonal_preconditioning=False, solver='L-BFGS', copy_graph=True, model_graph_options=None)#

Model providing MGVI samples.

The model implements the Metric Gaussian Variational Inference (MGVI) algorithm (see https://arxiv.org/abs/1901.11033). MGVI approximates the posterior distribution by a multivariate Gaussian distribution with the Fisher metric serving as the covariance. This variational approximation is performed in a transformed parameter space in which the a priori distribution is Gaussian. The Gaussian samples are generated in this transformed parameter space, which means that the sampled parameters themselves do not have to follow a Gaussian distribution.

Parameters:
  • graph (halerium.core.Graph) – The graph of the model.

  • data (halerium.core.DataLinker, dict, optional) – The data linker or dict containing data constraining the model. The default is None.

  • compiler (optional) – The compiler instance or class for compiling the model. The default is None, in which case a TensorFlow compiler is used.

  • initial_values (dict, optional) – A dictionary containing (static) variables in the graph as keys and their initial values as values. The model then tries to convert these into appropriate initial values for the solver.

  • initial_source_values (dict, optional) – A dictionary containing (static) variables in the graph as keys and initial value for their source as values. In addition, strings are accepted as keys, but are ignored unless the model can interpret these. This attribute may be used to pass initial values for trainable model parameters obtained from training of earlier models with the same (or a sufficiently similar) graph. Other usage is not recommended (unless users know exactly how the model handles its sources and their initial values).

  • random_initial_value_scale (float, None, optional) – The scale of any randomly drawn initial values (relative to the scale of the respective variables as given by the graph). The default is 1e-4.

  • sample_caching (bool, optional) – Whether generated samples are cached or not. Caching significantly speeds up the evaluation of subsequent calls to e.g. get_means. The default is True.

  • sample_symmetry (bool, optional) – Whether to mirror generated samples. The mirror of a sample is an equally likely sample from the Gaussian approximation. If True, mirror samples are used to obtain the required number of samples in half the time. The default is True.

  • fisher_diagonal_preconditioning (bool, optional) – If True, the diagonal of the Fisher matrix is constructed and used for preconditioning of the matrix inversion. This can speed up the convergence but also increases memory consumption. The default is False.

  • solver (str, None, optional) – The algorithm used to solve the model. Current choices are ‘L-BFGS’ (limited-memory BFGS) and ‘NGD’ (natural gradient descent). The default is ‘L-BFGS’.

  • copy_graph (bool, optional) – Whether the model should make a copy of the graph for its own use, or just keep the graph itself as attribute. Users should leave this set to the default True, unless they are certain that the graph won’t be altered by the user or other code. Changes to a graph that a model holds directly (i.e. not as a copy) make that model inconsistent and likely cause errors. The default is True.

  • model_graph_options (dict, optional) – The options for creating the model graph. The default is None.
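
A minimal usage sketch follows. How the graph `g`, its variable `g.x`, and the observed values `x_data` are built is assumed here and not covered on this page; only the MGVIModel calls follow the signatures documented above.

    from halerium.core.model import MGVIModel

    # Assumptions: `g` is an existing halerium.core.Graph, `g.x` is a variable
    # in that graph, and `x_data` holds observed values for it; keying the data
    # dict by graph variables is also an assumption.
    model = MGVIModel(graph=g,
                      data={g.x: x_data},   # a halerium.core.DataLinker is also accepted
                      solver="L-BFGS",      # or "NGD"
                      sample_caching=True)

    model.solve()                   # run the MGVI optimization
    x_mean = model.get_means(g.x)   # posterior mean estimate for g.x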

property apply_fisher_metric#
apply_to_samples(fetches, function, n_samples)#

Draw samples and apply a function to them.

Parameters:
  • fetches – The variables to generate sample data for.

  • function (callable) – The function to apply to the sample data.

  • n_samples (int) – The number of samples to draw from the model.

Returns:

The result of applying the function to the sampled data.

Return type:

result
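
As an illustration, the sketch below reduces drawn samples with a NumPy median; the variable `g.x` and the assumption that the sampled data are stacked along the first axis are hypothetical, only the method call itself is documented above.

    import numpy as np

    # Assumptions: `model` is a solved MGVIModel and `g.x` is a variable in its
    # graph; that the sampled data arrive stacked along the first axis is an
    # assumption as well.
    x_median = model.apply_to_samples(fetches=g.x,
                                      function=lambda s: np.median(s, axis=0),
                                      n_samples=50)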

assert_is_trained()#

Check if model is trained.

Return type:

None.

Raises:

RuntimeWarning – If model is not trained.

get_example(fetches)#

Draw an example from the model.

Parameters:

fetches – The variables to generate example values for.

Returns:

The example data.

Return type:

example

get_means(fetches, n_samples=100)#

Estimate mean values.

Parameters:
  • fetches – The variables to estimate mean values for.

  • n_samples (int) – The number of samples to estimate the means from.

Returns:

The estimated means of the variables.

Return type:

means

get_posterior_graph(name=None, n_samples=100)#

Create posterior graph from trained model.

Parameters:
  • name (str) – The name to give to the posterior graph.

  • n_samples (int) – The number of samples to estimate the posterior distributions from.

Returns:

post_graph – The posterior graph.

Return type:

halerium.core.Graph
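
For example (assuming a solved model `model`; the graph name is arbitrary):

    # Assumption: `model` is a solved MGVIModel.
    post_graph = model.get_posterior_graph(name="posterior_graph", n_samples=200)
    # post_graph is a halerium.core.Graph and could, e.g., serve as the graph
    # of a follow-up model.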

get_samples(fetches, n_samples=1)#

Draw samples from the model.

Parameters:
  • fetches – The variables to generate sample data for.

  • n_samples (int) – The number of samples to draw from the model.

Returns:

The sampled data.

Return type:

samples
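
For example (assuming a solved model `model` and a graph variable `g.x`):

    # Assumptions: `model` is a solved MGVIModel and `g.x` is a variable in its graph.
    x_samples = model.get_samples(fetches=g.x, n_samples=10)   # 10 posterior samples of g.x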

get_source_values(return_all=False)#

Get source values.

Prior to any training, these are just any initial values provided by the caller to the model upon construction. After training, these are the final values of the sources at the end of the training.

Parameters:

return_all (bool) – Whether to return all source values, or just those for static variables (the default).

Returns:

The source values.

Return type:

source_values
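
A sketch of the warm-start use case mentioned under initial_source_values above, assuming `g` and `new_data` are already defined; only the calls themselves are documented on this page.

    # Assumptions: `model` is a solved MGVIModel for graph `g`, and `new_data`
    # holds the data for a follow-up model over the same graph.
    trained_sources = model.get_source_values()   # source values for static variables
    new_model = MGVIModel(graph=g,
                          data=new_data,
                          initial_source_values=trained_sources)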

get_standard_deviations(fetches, n_samples=100)#

Estimate standard deviations.

Parameters:
  • fetches – The variables to estimate standard deviations for.

  • n_samples (int) – The number of samples to estimate the standard deviations from.

Returns:

The estimated standard deviations of the variables.

Return type:

standard_deviations

get_variances(fetches, n_samples=100)#

Estimate variances.

Parameters:
  • fetches – The variables to estimate variances for.

  • n_samples (int) – The number of samples to estimate the variances from.

Returns:

The estimated variances of the variables.

Return type:

variances
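
For example, means with uncertainty estimates could be obtained as follows (assuming a solved model `model` and a graph variable `g.x`):

    # Assumptions: `model` is a solved MGVIModel and `g.x` is a variable in its graph.
    x_mean = model.get_means(g.x, n_samples=200)
    x_std = model.get_standard_deviations(g.x, n_samples=200)
    x_var = model.get_variances(g.x, n_samples=200)
    # With sample_caching=True (the default), such subsequent calls are sped up
    # by the sample cache.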

property global_source_value#
property is_trained#

Whether the model has been trained.

property model_graph#

The model graph (don’t modify it yourself).

recognized_solvers = ('L-BFGS', 'NGD')#
reset_sample_cache()#

Reset sample cache.

Return type:

None.

solve(tol=0.0001, n_samples=16, max_iter=100, solver=None)#

Solve the model.

Parameters:
  • tol (float, optional) – The tolerance that decides when to stop the solver iterations. The default is 1e-4.

  • n_samples (int) – The number of samples used for the final iteration(s). The default is 16.

  • max_iter (int) – (deprecated) The maximum number of iterations.

  • solver (str, optional) – The solver to be used. Either “NGD” for natural gradient descent or “L-BFGS” for limited memory BFGS.

Return type:

None
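
For example (the parameter values below are illustrative, not recommendations):

    # Assumption: `model` is an MGVIModel constructed as shown above.
    model.solve(tol=1e-4, n_samples=16, solver="NGD")
    model.assert_is_trained()   # warns if the model did not end up trained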

update_lbfgs(n_samples=8, pgtol=0.0001, max_iter=5000)#

Update the source with an L-BFGS minimization.

Parameters:
  • n_samples (int, optional) – The number of samples used to estimate the D_KL. The default is 8.

  • pgtol (float, optional) – The projected gradient tolerance of the L-BFGS algorithm. The default is 1e-4.

  • max_iter (int, optional) – The maximum number of iterations. The default is 5000.

update_ngd(n_samples=8, tol=0.0001, n_sub_iter=3, max_iter_line_search=100)#

Update the model using natural gradient descent steps.

Parameters:
  • n_samples (int, optional) – The number of samples used to estimate the D_KL. The default is 8.

  • tol (float, optional) – The gradient tolerance to stop sub-iterations. The default is 1e-4.

  • n_sub_iter (int, optional) – The number of natural gradient steps to perform with the generated samples. The default is 3.

  • max_iter_line_search (int, optional) – The maximum number of iterations for the line search. The default is 100.
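
These low-level updates are normally driven by solve; a manual sketch could look like the following, where the loop length and the final L-BFGS refinement are illustrative choices, not a prescribed recipe.

    # Assumption: `model` is a freshly constructed MGVIModel.
    for _ in range(10):   # illustrative number of outer iterations
        model.update_ngd(n_samples=8, tol=1e-4, n_sub_iter=3)
    model.update_lbfgs(n_samples=16, pgtol=1e-4)   # illustrative final refinement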