# Source code for pints._sample_initial_points
#
# Defines method for initialising points for sampling and optimising
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import pints
import numpy as np
def sample_initial_points(function, n_points, random_sampler=None,
                          boundaries=None, max_tries=50, parallel=False,
                          n_workers=None):
    """
    Samples ``n_points`` parameter values to use as starting points in a
    sampling or optimisation routine on the given ``function``.

    How the initial points are determined depends on the arguments supplied.
    In order of precedence:

    1. If a method ``random_sampler`` is provided then this will be used to
       draw the random samples.
    2. If no sampler method is given but ``function`` is a
       :class:`LogPosterior` then the method ``function.log_prior().sample()``
       will be used.
    3. If no sampler method is supplied and ``function`` is not a
       :class:`LogPosterior` and if ``boundaries`` are provided then the
       method ``boundaries.sample()`` will be used to draw samples.

    A ``ValueError`` is raised if none of the above options are available.

    Each sample ``x`` is tested to ensure that ``function(x)`` returns a
    finite result and, if ``boundaries`` are supplied, that ``x`` lies within
    them. Samples that fail these tests are redrawn. Sampling proceeds in
    rounds: each round draws one new sample for every point that is still
    missing. At most ``max_tries`` rounds are performed, after which an error
    is raised.

    Parameters
    ----------
    function :
        A :class:`pints.ErrorMeasure` or a :class:`pints.LogPDF` that
        evaluates points in the parameter space. A :class:`pints.LogPDF` may
        additionally be a :class:`LogPosterior`, in which case its prior can
        be used to draw samples (see above).
    n_points : int
        The number of initial values to generate.
    random_sampler :
        A function that when called returns draws from a probability
        distribution of the same dimensionality as ``function``. The only
        argument to this function should be an integer specifying the number
        of draws.
    boundaries :
        An optional set of boundaries on the parameter space of class
        :class:`pints.Boundaries`.
    max_tries : int
        Maximum number of sampling rounds. Every round redraws all points
        that are still missing, so each individual point gets at most
        ``max_tries`` attempts (50 by default).
    parallel : bool
        Whether to evaluate ``function`` in parallel (defaults to False).
    n_workers : int
        Number of workers on which to run parallel evaluation. If ``None``
        (the default), the smaller of ``pints.ParallelEvaluator.cpu_count()``
        and ``n_points`` is used. Ignored when ``parallel`` is False.

    Returns
    -------
    A list of ``n_points`` accepted samples. Each entry is an element of a
    collection returned by ``random_sampler``.

    Raises
    ------
    ValueError
        If ``function`` is neither a :class:`pints.LogPDF` nor a
        :class:`pints.ErrorMeasure`; if ``boundaries`` is not a
        :class:`pints.Boundaries` or does not match the dimension of
        ``function``; if ``random_sampler`` is not callable; if no sampler
        can be determined; or if ``n_points`` is less than 1.
    RuntimeError
        If ``n_points`` acceptable samples could not be found within
        ``max_tries`` rounds.
    """
    # Check function
    if not isinstance(function, (pints.LogPDF, pints.ErrorMeasure)):
        raise ValueError(
            'function must be either pints.LogPDF or pints.ErrorMeasure.')

    # Check boundaries
    if boundaries is not None:
        if not isinstance(boundaries, pints.Boundaries):
            raise ValueError('boundaries must be a pints.Boundaries object.')
        elif boundaries.n_parameters() != function.n_parameters():
            raise ValueError('boundaries must match dimension of function.')

    # Check or set random sampler (precedence: explicit sampler, then the
    # prior of a LogPosterior, then the boundaries)
    if random_sampler is None:
        if isinstance(function, pints.LogPosterior):
            random_sampler = function.log_prior().sample
        elif boundaries is not None:
            random_sampler = boundaries.sample
        else:
            raise ValueError(
                'If function is not a pints.LogPosterior and no boundaries'
                ' are given then a random_sampler must be supplied.')
    elif not callable(random_sampler):
        raise ValueError(
            'random_sampler must be a callable function, if supplied.')

    # Check number of initial points
    if n_points < 1:
        raise ValueError('Number of initial points must be 1 or more.')

    # Set up parallelisation
    if parallel:
        # Respect a caller-supplied worker count; otherwise default to no
        # more workers than there are points to evaluate.
        if n_workers is None:
            n_workers = min(pints.ParallelEvaluator.cpu_count(), n_points)
        evaluator = pints.ParallelEvaluator(function, n_workers=n_workers)
    else:
        evaluator = pints.SequentialEvaluator(function)

    # Go! Each round only draws as many samples as there are points still
    # missing, so len(x0) can never exceed n_points.
    x0 = []
    n_tries = 0
    while len(x0) < n_points and n_tries < max_tries:
        xs = random_sampler(n_points - len(x0))
        fxs = evaluator.evaluate(xs)
        for x, fx in zip(xs, fxs):
            if not np.isfinite(fx):
                continue
            if boundaries is None or boundaries.check(x):
                x0.append(x)
        n_tries += 1

    if len(x0) < n_points:
        raise RuntimeError(
            'Initialisation failed since function not finite or within ' +
            'bounds at initial points after ' + str(max_tries) + ' attempts.')
    return x0