Source code for pints._abc
#
# Sub-module containing ABC inference routines
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import pints
import numpy as np
[docs]
class ABCSampler(pints.Loggable, pints.TunableMethod):
"""
Abstract base class for ABC methods.
All ABC samplers implement the :class:`pints.Loggable` and
:class:`pints.TunableMethod` interfaces.
"""
[docs]
def name(self):
"""
Returns this method's full name.
"""
raise NotImplementedError
[docs]
def ask(self):
"""
Returns a parameter vector sampled from the LogPrior.
"""
raise NotImplementedError
[docs]
def tell(self, x):
"""
Performs an iteration of the ABC algorithm, using the
parameters specified by ask.
Expects to receive x as a sequence of length at least 1.
Returns the accepted parameter values.
"""
raise NotImplementedError
[docs]
class ABCController(object):
"""
Samples from a :class:`pints.LogPrior`.
Properties related to the number of iterations, parallelisation,
threshold, and number of parameters to sample can be set directly on the
``ABCController`` object. Afterwards the ABC routine can be run.
Parameters
----------
error_measure
An error measure to evaluate on a problem, given a forward model,
simulated and observed data, and times
log_prior
A :class:`LogPrior` function from which parameter values are sampled
method
The class of :class:`ABCSampler` to use. If no method is specified,
:class:`RejectionABC` is used.
Example
-------
::
abc = pints.ABCController(error_measure, log_prior)
abc.set_max_iterations(1000)
posterior_estimate = abc.run()
"""
def __init__(self, error_measure, log_prior, method=None):
# Store function
if not isinstance(log_prior, pints.LogPrior):
raise ValueError('Given function must extend pints.LogPrior.')
self._log_prior = log_prior
# Check error_measure
if not isinstance(error_measure, pints.ErrorMeasure):
raise ValueError('Given error_measure must extend '
'pints.ErrorMeasure')
self._error_measure = error_measure
# Check if number of parameters from prior matches that of error
# measure
if self._log_prior.n_parameters() != \
self._error_measure.n_parameters():
raise ValueError('Number of parameters in prior must match number '
'of parameters in error measure.')
# Get number of parameters
self._n_parameters = self._log_prior.n_parameters()
# Set rejection ABC as default method
if method is None:
method = pints.RejectionABC
else:
try:
ok = issubclass(method, ABCSampler)
except TypeError: # Not a class
ok = False
if not ok:
raise ValueError('Given method must extend ABCSampler.')
# Initialisation
# Parallelisation
self._parallel = False
self._n_workers = 1
# Maximum number of iterations as a stopping criterion
self._max_iterations = 10000
# Maximum number of target samples to obtain
# in the estimated posterior
self._n_samples = 500
# The sampler object uses the prior distribution
self._sampler = method(log_prior)
# Logging
self._log_to_screen = True
self._log_filename = None
self._log_csv = False
self.set_log_interval()
[docs]
def set_log_interval(self, iters=20, warm_up=3):
"""
Changes the frequency with which messages are logged.
Parameters
----------
iters
A log message will be shown every ``iters`` iterations.
warm_up
A log message will be shown every iteration, for the first
``warm_up`` iterations.
"""
iters = int(iters)
if iters < 1:
raise ValueError("Interval must be greater than 0.")
warm_up = max(0, int(warm_up))
self._message_interval = iters
self._message_warm_up = warm_up
[docs]
def set_log_to_file(self, filename=None, csv=False):
"""
Enables progress logging to file when a filename is passed in, disables
it if ``filename`` is ``False`` or ``None``.
The argument ``csv`` can be set to ``True`` to write the file in comma
separated value (CSV) format. By default, the file contents will be
similar to the output on screen.
"""
if filename:
self._log_filename = str(filename)
self._log_csv = True if csv else False
else:
self._log_filename = None
self._log_csv = False
[docs]
def set_log_to_screen(self, enabled):
"""
Enables or disables progress logging to screen.
"""
self._log_to_screen = True if enabled else False
[docs]
def max_iterations(self):
"""
Returns the maximum iterations if this stopping criterion is set, or
``None`` if it is not. See :meth:`set_max_iterations()`.
"""
return self._max_iterations
[docs]
def n_samples(self):
"""
Returns the target number of samples to obtain in the estimated
posterior.
"""
return self._n_samples
[docs]
def parallel(self):
"""
Returns the number of parallel worker processes this routine will be
run on, or ``False`` if parallelisation is disabled.
"""
return self._n_workers if self._parallel else False
[docs]
def run(self):
"""
Runs the ABC sampler.
"""
if self._max_iterations is None:
raise ValueError("At least one stopping criterion must be set.")
# Iteration and evaluation counting
iteration = 0
evaluations = 0
accepted_count = 0
# Choose method to evaluate
f = self._error_measure
# Create evaluator
if self._parallel:
n_workers = self._n_workers
evaluator = pints.ParallelEvaluator(f, n_workers=n_workers)
else:
evaluator = pints.SequentialEvaluator(f)
# Set up progress reporting
next_message = 0
# Start logging
logging = self._log_to_screen or self._log_filename
if logging:
if self._log_to_screen:
print('Using ' + str(self._sampler.name()))
if self._parallel:
print('Running in parallel with ' + str(n_workers) +
' worker processess.')
else:
print('Running in sequential mode.')
# Set up logger
logger = pints.Logger()
if not self._log_to_screen:
logger.set_stream(None)
if self._log_filename:
logger.set_filename(self._log_filename, csv=self._log_csv)
# Add fields to log
max_iter_guess = max(self._max_iterations or 0, 10000)
max_eval_guess = max_iter_guess
logger.add_counter('Iter.', max_value=max_iter_guess)
logger.add_counter('Eval.', max_value=max_eval_guess)
logger.add_float('Acceptance rate')
self._sampler._log_init(logger)
# Note: removed units from time field, see
# https://github.com/pints-team/pints/issues/1467
logger.add_time('Time')
# Start sampling
timer = pints.Timer()
running = True
# Specifying the number of samples we want to get
# from the prior at once. It depends on whether we
# are using parallelisation and how many workers
# are being used.
if self._parallel:
n_requested_samples = self._n_workers
else:
n_requested_samples = 1
samples = []
# Sample until we find an acceptable sample
while running:
accepted_vals = None
while accepted_vals is None:
# Get points from prior
xs = self._sampler.ask(n_requested_samples)
# Simulate and get error
fxs = evaluator.evaluate(xs)
evaluations += self._n_workers
# Tell sampler errors and get list of acceptable parameters
accepted_vals = self._sampler.tell(fxs)
accepted_count += len(accepted_vals)
for val in accepted_vals:
samples.append(val)
iteration += 1
# Log progress
if logging and iteration >= next_message:
# Log state
logger.log(iteration, evaluations, (
accepted_count / evaluations))
self._sampler._log_write(logger)
logger.log(timer.time())
# Choose next logging point
if iteration < self._message_warm_up:
next_message = iteration + 1
else:
next_message = self._message_interval * (
1 + iteration // self._message_interval)
if iteration >= self._max_iterations:
running = False
halt_message = ('Halting: Maximum number of iterations ('
+ str(iteration) + ') reached. Only '
+ str(accepted_count) + ' samples were '
+ 'obtained.')
elif accepted_count >= self._n_samples:
running = False
halt_message = ('Halting: Target number of samples ('
+ str(accepted_count) + ') reached.')
# Log final state and show halt message
if logging:
logger.log(iteration, evaluations)
self._sampler._log_write(logger)
logger.log(timer.time())
if self._log_to_screen:
print(halt_message)
samples = np.array(samples)
return samples
[docs]
def log_filename(self):
"""
Returns the path to the controller log, or ``None`` if not set.
"""
return self._log_filename
[docs]
def sampler(self):
"""
Returns the underlying sampler object.
"""
return self._sampler
[docs]
def set_max_iterations(self, iterations=10000):
"""
Adds a stopping criterion, allowing the routine to halt after the
given number of ``iterations``.
This criterion is enabled by default. To disable it, use
``set_max_iterations(None)``.
"""
if iterations is not None:
iterations = int(iterations)
if iterations < 0:
raise ValueError(
'Maximum number of iterations cannot be negative.')
self._max_iterations = iterations
[docs]
def set_n_samples(self, n_samples=500):
"""
Sets a target number of samples
"""
self._n_samples = n_samples
[docs]
def set_parallel(self, parallel=False):
"""
Enables/disables parallel evaluation.
If ``parallel=True``, the method will run using a number of worker
processes equal to the detected cpu core count. The number of workers
can be set explicitly by setting ``parallel`` to an integer greater
than 0.
Parallelisation can be disabled by setting ``parallel`` to ``0`` or
``False``.
"""
if parallel is True:
self._n_workers = pints.ParallelEvaluator.cpu_count()
self._parallel = True
elif parallel >= 1:
self._parallel = True
self._n_workers = int(parallel)
else:
self._parallel = False
self._n_workers = 1