Source code for pints._abc._abc_rejection
#
# ABC Rejection method
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import pints
import numpy as np
[docs]
class RejectionABC(pints.ABCSampler):
    r"""
    Implements the rejection ABC algorithm as described in [1].

    Here is a high-level description of the algorithm:

    .. math::
        \begin{align}
        \theta^* &\sim p(\theta) \\
        x &\sim p(x|\theta^*) \\
        \textrm{if } s(x) < \textrm{threshold}, \textrm{then} \\
        \theta^* \textrm{ is added to list of samples} \\
        \end{align}

    In other words, the first two steps sample parameters
    from the prior distribution :math:`p(\theta)` and then sample
    simulated data from the sampling distribution (conditional on
    the sampled parameter values), :math:`p(x|\theta^*)`.
    In the end, if the error measure :math:`s(x)` between the simulated
    data and the original data is strictly below the threshold, the
    sampled parameters are added to the list of samples.

    Parameters
    ----------
    log_prior
        The prior used to propose candidate parameters. Only its
        ``sample(n_samples)`` method is used, in :meth:`ask`.

    References
    ----------
    .. [1] "Approximate Bayesian Computation (ABC) in practice". Katalin
           Csillery, Michael G.B. Blum, Oscar E. Gaggiotti, Olivier Francois
           (2010) Trends in Ecology & Evolution
           https://doi.org/10.1016/j.tree.2010.04.001
    """

    def __init__(self, log_prior):
        self._log_prior = log_prior
        # Default acceptance threshold; can be changed with set_threshold().
        self._threshold = 1
        # Parameter sets proposed by the most recent call to ask().
        self._xs = None
        # True between a call to ask() and the matching call to tell().
        self._ready_for_tell = False

    def name(self):
        """ See :meth:`pints.ABCSampler.name()`. """
        return 'Rejection ABC'

    def ask(self, n_samples):
        """
        See :meth:`ABCSampler.ask()`.

        Draws ``n_samples`` parameter sets from the prior, remembers them so
        that :meth:`tell` can select the accepted ones, and returns them.

        Raises a ``RuntimeError`` if called again before :meth:`tell`.
        """
        if self._ready_for_tell:
            raise RuntimeError('Ask called before tell.')
        self._xs = self._log_prior.sample(n_samples)
        self._ready_for_tell = True
        return self._xs

    def tell(self, fx):
        """
        See :meth:`ABCSampler.tell()`.

        ``fx`` must contain one error value per parameter set returned by
        the last call to :meth:`ask`. A parameter set is accepted if its
        error is strictly below :meth:`threshold`.

        Returns a list with the accepted parameter sets (each converted to a
        list), or ``None`` if no parameter set was accepted.

        Raises a ``RuntimeError`` if called without a preceding call to
        :meth:`ask`, and a ``ValueError`` if the number of error values does
        not match the number of proposed parameter sets.
        """
        if not self._ready_for_tell:
            raise RuntimeError('Tell called before ask.')
        self._ready_for_tell = False

        fx = pints.vector(fx)
        # Without this check, a short ``fx`` would silently drop proposals
        # and a long one could fail with an unrelated IndexError.
        if len(fx) != len(self._xs):
            raise ValueError(
                'Expected ' + str(len(self._xs)) + ' error values (one per'
                ' proposed sample), got ' + str(len(fx)) + '.')

        accepted = fx < self._threshold
        if not np.any(accepted):
            return None
        return [self._xs[i].tolist() for i in np.flatnonzero(accepted)]

    def threshold(self):
        """
        Returns threshold error distance that determines if a sample is
        accepted (if ``error < threshold``).
        """
        return self._threshold

    def set_threshold(self, threshold):
        """
        Sets threshold error distance that determines if a sample is accepted
        (if ``error < threshold``).

        The threshold is converted to and stored as a ``float``. Raises a
        ``ValueError`` if it is not strictly positive, which includes NaN (a
        NaN threshold would cause every sample to be rejected).
        """
        x = float(threshold)
        # ``not x > 0`` rather than ``x <= 0`` so that NaN is rejected too.
        if not x > 0:
            raise ValueError('Threshold must be greater than zero.')
        # Store the validated float, not the raw argument: e.g. a string
        # "0.5" passes float() but would break ``fx < threshold`` in tell().
        self._threshold = x