Source code for pints._mcmc._slice_stepout

# -*- coding: utf-8 -*-
#
# Slice Sampling with Stepout MCMC Method
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import pints
import numpy as np


[docs] class SliceStepoutMCMC(pints.SingleChainMCMC): r""" Implements Slice Sampling with Stepout, as described in [1]_. This is a univariate method, which is applied in a Slice-Sampling-within-Gibbs framework to allow MCMC sampling from multivariate models. Generates samples by sampling uniformly from the volume underneath the posterior (``f``). It does so by introducing an auxiliary variable (``y``) and by definying a Markov chain. If the distribution is univariate, sampling follows: 1. Calculate the PDF (:math:`f(x0)`) of the current sample (:math:`x0`). 2. Draw a real value (:math:`y`) uniformly from :math`(0, f(x0))`, defining a horizontal 'slice' :math:`S = {x: y < f (x)}`. Note that :math:`x0` is always within :math:`S`. 3. Find an interval (:math:`I = (L, R)`) around :math:`x0` that contains all, or much, of the slice. 4. Draw a new point (:math:`x1`) from the part of the slice within this interval. If the distribution is multivariate, we apply the univariate algorithm to each variable in turn, where the other variables are set at their current values. This implementation uses the "Stepout" method to estimate the interval :math:`I = (L, R)`, as described in [1] Fig. 3. pp.715 and consists of the following steps: 1. :math:`U \sim uniform(0, 1)` 2. :math:`L = x_0 - wU` 3. :math:`R = L + w` 4. :math:`V \sim uniform(0, 1)` 5. :math:`J = floor(mV)` 6. :math:`K = (m - 1) - J` 7. while :math:`J > 0` and :math:`y < f(L), L = L - w, J = J - 1` 8. while :math:`K > 0` and :math:`y < f(R), R = R + w, K = K - 1` Intuitively, the interval ``I`` is estimated by expanding the initial interval by a width ``w`` in each direction until both edges fall outside the slice, or until a pre-determined limit is reached. The parameters ``m`` (an integer, which determines the limit of slice size) and ``w`` (the estimate of typical slice width) are hyperparameters. To sample from the interval :math:`I = (L, R)`, such that the sample ``x`` satisfies :math:`y < f(x)`, we use the "Shrinkage" procedure, which reduces the size of the interval after rejecting a trial point, as defined in [1] Fig. 5. pp.716. This algorithm consists of the following steps: 1. :math:`\bar{L} = L` and :math:`\bar{R} = R` 2. Repeat: a. :math:`U \sim uniform(0, 1)` b. :math:`x_1 = \bar{L} + U (\bar{R} - \bar{L})` c. if :math:`y < f(x_1)` accept :math:`x_1` and exit loop, else: if :math:`x_1 < x_0`, :math:`\bar{L} = x_1` else :math:`\bar{R} = x_1` Intuitively, we uniformly sample a trial point from the interval ``I``, and subsequently shrink the interval each time a trial point is rejected. The following implementation includes the possibility of carrying out "overrelaxed" slice sampling steps, as described in [1] pp. 726. Overrelaxed steps increase sampling efficiency in highly correlated unimodal distributions by suppressing the random walk behaviour of single-variable slice sampling: each variable is still updated in turn, but rather than drawing a new value for a variable from its conditional distribution independently of the current value, the new value is instead chosen to be on the opposite side of the mode from the current value. The interval ``I`` is still calculated via Stepout, and the edges ``l,r`` are used to estimate the slice endpoints via bisection. To obtain a full sampling scheme, overrelaxed updates are alternated with normal Stepout updates. To obtain the full benefits of overrelaxation, [1] suggests to set almost every update to being overrelaxed and to set the limit ``m`` for finding ``I`` to infinity. The algorithm consists of the following steps: 1. :math:`\bar{L} = L, \bar{R} = R, \bar{w} = w, \bar{a} = a` 2. while :math:`R - L < 1.1 * w`: a. :math:`M = (\bar{L} + \bar{R})/ 2` b. if :math:`\bar{a} = 0 ` or :math:`y < f(M)`, exit loop c. if :math:`x_0 > M`, :math:`\bar{L} = M` else, :math:`\bar{R} = M` d. :math:`\bar{a} = \bar{a} - 1` e. :math:`\bar{w} = \bar{w} / 2` 3. :math:`\hat{L} = \bar{L}, \hat{R} = \bar{R}` 4. while :math:`\bar{a} > 0`: a. :math:`\bar{a} = \bar{a} - 1` b. :math:`\bar{w} = \bar{w} \ 2` c. if :math:`y >= f(\hat{L} + \bar{w})`, then :math:`\hat{L} = \hat{L} + \bar{w}` d. if :math:`y >= f(\hat{R} - \bar{w})`, then :math:`\hat{R} = \hat{R} - \bar{W}` 5. :math:`x_1 = \hat{L} + \hat{R} - x_0` 6. if :math:`x_1 < \bar{L}` or :math:`x_1 >= \bar{R}` or :math:`y >= f(x_1)`, then :math:`x_1 = x_0` The probability of pursuing an overrelaxed step and the number of bisection iterations are hyperparameters. To avoid floating-point underflow, we implement the suggestion advanced in [1]_ pp.712. We use the log pdf of the un-normalised posterior (:math:`g(x) = log(f(x))`) instead of :math:`f(x)`. In doing so, we use an auxiliary variable :math:`z = log(y) = g(x0) - \epsilon`, where :math:`\epsilon \sim \text{exp}(1)` and define the slice as :math:`S = {x : z < g(x)}`. Extends :class:`SingleChainMCMC`. References ---------- .. [1] Neal, R.M., 2003. "Slice sampling". The annals of statistics, 31(3), pp.705-767. https://doi.org/10.1214/aos/1056562461 """ def __init__(self, x0, sigma0=None): super(SliceStepoutMCMC, self).__init__(x0, sigma0) # Set initial state self._x0 = np.asarray(x0, dtype=float) self._running = False self._ready_for_tell = False self._current = None self._current_log_y = None self._temporary_log_pdf = None self._proposed = None self._overrelaxed_step = False # Default initial interval width ``w`` used in the Stepout procedure # to expand the interval self._w = np.abs(self._x0) self._w[self._w == 0] = 1 self._w = 0.1 * self._w # Default integer limiting the size of the interval to ``m * w``` self._m = 50 # Flag to initialise the expansion of the interval ``I=(L,R)`` self._first_expansion = False # Flag indicating whether the interval expansion is concluded self._interval_found = False # Number of steps used for expanding the interval ``I`` self._j = None self._k = None # Flags used to calculate log_pdf of initial interval edges ``l,r``` self._init_left = False self._init_right = False # Edges of the interval ``I`` self._l = None self._r = None # Parameter values at interval edges self._temp_l = None self._temp_r = None # Log_pdf of interval edges self._fx_l = None self._fx_r = None # Flags to indicate the interval edge to update self._set_l = False self._set_r = False # Index of parameter ``xi``` we are updating of the sample # ``x = (x1,...,xn)`` self._active_param_index = 0 # Probability of overrelaxed step self._prob_overrelaxed = 0 # Interval edges used in overrelaxed step self._l_bar = None self._r_bar = None self._l_hat = None self._r_hat = None self._l_bisection = None self._r_bisection = None self._temp_l_bisection = None self._temp_r_bisection = None self._set_l_bisection = False self._set_r_bisection = False # Integer limiting overrelaxation endpoint accuracy to ``2^(-a) * w`` self._a = 10 # Interval width ``w_bar`` used in the overrelaxation step self._w_bar = None # Mid-point of overrelaxed interval self._mid = None self._temp_mid = None self._fx_mid = None # Flags used for overrelaxed step self._init_overrelaxation = False self._init_narrowing = False self._init_bisection = False self._bisection = False
[docs] def ask(self): """ See :meth:`SingleChainMCMC.ask()`. """ # Check ask/tell pattern if self._ready_for_tell: raise RuntimeError('Ask() called when expecting call to tell().') # Initialise on first call if not self._running: self._running = True # Very first iteration if self._current is None: # Ask for the log pdf of x0 self._ready_for_tell = True return np.array(self._x0, copy=True) # Initialise the expansion of interval ``I=(l,r)`` if self._first_expansion: # Set initial values for l and r u = np.random.uniform() self._l = (self._proposed[self._active_param_index] - self._w[self._active_param_index] * u) self._r = self._l + self._w[self._active_param_index] # Set maximum number of steps for expansion to the left (j) # and right (k) v = np.random.uniform() self._j = np.floor(self._m * v) self._k = (self._m - 1) - self._j # Initialise arrays used for calculating the log_pdf of the edges self._temp_l = np.array(self._proposed, copy=True) self._temp_r = np.array(self._proposed, copy=True) self._temp_l[self._active_param_index] = self._l self._temp_r[self._active_param_index] = self._r # Set flags to calculate log_pdf of ``l,r`` self._init_left = True self._init_right = True self._first_expansion = False # Ask for log_pdf of initial edges ``l,r``` if self._init_left: self._ready_for_tell = True return np.array(self._temp_l, copy=True) if self._init_right: self._ready_for_tell = True return np.array(self._temp_r, copy=True) # Expand the interval ``I``` until edges ``l,r`` are outside the slice # or we have reached limit of expansion steps # Check whether we can expand to the left if self._j > 0 and self._current_log_y < self._fx_l: # Set flag to indicate that we are updating the left edge self._set_l = True # Expand interval to the left self._l -= self._w[self._active_param_index] self._temp_l[self._active_param_index] = self._l self._j -= 1 # Ask for log pdf of the updated left edge self._ready_for_tell = True return np.array(self._temp_l, copy=True) # Reset flag now that we have finished updating the left edge self._set_l = False # Check whether we can expand to the right if self._k > 0 and self._current_log_y < self._fx_r: # Set flag to indicate that we are updating the right edge self._set_r = True # Expand interval to the right self._r += self._w[self._active_param_index] self._temp_r[self._active_param_index] = self._r self._k -= 1 # Ask for log pdf of the updated right edge self._ready_for_tell = True return np.array(self._temp_r, copy=True) # Reset flag now that we have finished updating the right edge self._set_r = False # Now that we have expanded the interval, set flag self._interval_found = True # Overrelaxed step if self._overrelaxed_step: # Initialise variables for overrelaxed step if self._init_overrelaxation: self._l_bar = self._l self._r_bar = self._r self._w_bar = self._w[self._active_param_index] self._a_bar = self._a self._temp_mid = np.array(self._proposed, copy=True) self._init_overrelaxation = False self._init_narrowing = True self._init_bisection = True # If interval is of size ``w``, narrow it until mid-point is # within the slice if (((self._r - self._l) < 1.1 * self._w[self._active_param_index]) and self._init_narrowing): # Ask for log pdf of interval mid point self._mid = (self._l_bar + self._r_bar) / 2 self._temp_mid[self._active_param_index] = self._mid self._ready_for_tell = True return np.array(self._temp_mid, copy=True) # Initialise endpoints for bisection if self._init_bisection: self._l_hat = self._l_bar self._r_hat = self._r_bar self._init_bisection = False # Apply bisection to endpoint edges if self._a_bar > 0: # Prepare bisection if self._bisection: self._w_bar = self._w_bar / 2 self._temp_l_bisection = np.array(self._proposed, copy=True) self._temp_r_bisection = np.array(self._proposed, copy=True) self._set_l_bisection = True self._set_r_bisection = True self._bisection = False # Apply bisection to left edge if self._set_l_bisection: self._l_bisection = (self._l_hat + self._w_bar) self._temp_l_bisection[self._active_param_index] = ( self._l_bisection) self._ready_for_tell = True return np.array(self._temp_l_bisection, copy=True) # Apply bisection to right edge if self._set_r_bisection: self._r_bisection = (self._r_hat - self._w_bar) self._temp_r_bisection[self._active_param_index] = ( self._r_bisection) self._ready_for_tell = True return np.array(self._temp_r_bisection, copy=True) # Find candidate point by flipping from the current point to # the opposide side self._proposed[self._active_param_index] = ( self._l_hat + self._r_hat - self._current[self._active_param_index]) self._ready_for_tell = True return np.array(self._proposed, copy=True) else: # Sample new trial point by sampling uniformly from the # interval ``I=(l,r)`` u = np.random.uniform() self._proposed[self._active_param_index] = \ self._l + u * (self._r - self._l) # Send trial point for checks self._ready_for_tell = True return np.array(self._proposed, copy=True)
[docs] def bisection_steps(self): """ Returns integer limit overrelaxation endpoint accuracy to ``2^(-bisection steps) * width``. """ return self._a
[docs] def current_slice_height(self): """ Returns current height value used to define the current slice. """ return self._current_log_y
[docs] def expansion_steps(self): """ Returns integer used for limiting interval expansion. """ return self._m
[docs] def prob_overrelaxed(self): """ Returns probability of carrying out an overrelaxed step. """ return self._prob_overrelaxed
[docs] def name(self): """ See :meth:`pints.MCMCSampler.name()`. """ return 'Slice Sampling - Stepout'
[docs] def n_hyper_parameters(self): """ See :meth:`TunableMethod.n_hyper_parameters()`. """ return 4
[docs] def set_bisection_steps(self, a): """ Set integer for limiting the bisection process in overrelaxed steps. """ a = int(a) if a < 0: raise ValueError( 'Integer must be positive (to limit overrelaxation endpoint' ' accuracy to (2 ^ (-bisection steps) * width).') self._a = a
[docs] def set_expansion_steps(self, m): """ Set integer for limiting the interval expansion. """ m = int(m) if m <= 0: raise ValueError('Integer must be positive to limit the' ' interval size to ``integer * width``.') self._m = m
[docs] def set_hyper_parameters(self, x): """ The hyper-parameter vector is ``[width, expansion steps, prob_overrelaxed, bisection steps]``. See :meth:`TunableMethod.set_hyper_parameters()`. """ self.set_width(x[0]) self.set_expansion_steps(x[1]) self.set_prob_overrelaxed(x[2]) self.set_bisection_steps(x[3])
[docs] def set_prob_overrelaxed(self, prob): """ Set the probability of a step being overrelaxed. """ prob = float(prob) if prob < 0 or prob > 1: raise ValueError('Probability must be positive and <= 1.') self._prob_overrelaxed = prob
[docs] def set_width(self, w): """ Sets the width for generating the interval. This can either be a single number or an array with the same number of elements as the number of variables to update. """ if np.isscalar(w): w = np.ones(self._n_parameters) * w else: w = np.array(w, copy=True) if len(w) != self._n_parameters: raise ValueError( 'Width for interval expansion must a scalar or an array' ' of length n_parameters.') if np.any(w < 0): raise ValueError('Width for interval expansion must be positive.') self._w = w
[docs] def tell(self, fx): """ See :meth:`pints.SingleChainMCMC.tell()`. """ # Check ask/tell pattern if not self._ready_for_tell: raise RuntimeError('Tell called before proposal was set.') self._ready_for_tell = False # Ensure fx is a float fx = float(fx) # Very first call if self._current is None: # Check first point is somewhere sensible if not np.isfinite(fx): raise ValueError( 'Initial point for MCMC must have finite logpdf.') # Set current sample, log pdf of current sample and initialise # proposed sample for next iteration self._current = np.copy(self._x0) self._temporary_log_pdf = fx self._proposed = np.copy(self._current) # Sample height of the slice log_y for next iteration e = np.random.exponential(1) self._current_log_y = fx - e # Set flag to true as we need to initialise the interval expansion # for next iteration self._first_expansion = True # Check whether next mcmc step should be overrelaxed self._overrelaxed_step = (np.random.uniform() < self._prob_overrelaxed) if self._overrelaxed_step: self._init_overrelaxation = True self._bisection = True # Return first point in chain, which is x0 return np.copy(self._current), fx, True # While we expand the interval ``I=(l,r)``, we return None if not self._interval_found: # Set the log_pdf of the interval edge that we are expanding if self._set_l: self._fx_l = fx elif self._set_r: self._fx_r = fx elif self._init_left: self._fx_l = fx self._init_left = False elif self._init_right: self._fx_r = fx self._init_right = False return None # Overrelaxed step if self._overrelaxed_step: # When the interval is of size ``w``, narrow until mid-point # is inside the slice if (((self._r - self._l) < 1.1 * self._w[self._active_param_index]) and self._init_narrowing): self._fx_mid = fx # Once the mid-point is within the slice or narrowing limit is # reached, break narrowing loop if (self._a_bar == 0 or (self._current_log_y < self._fx_mid)): self._init_narrowing = False return None # Narrow interval if (self._current[self._active_param_index] > self._temp_mid[self._active_param_index]): self._l_bar = self._mid else: self._r_bar = self._mid self._a_bar -= 1 self._w_bar = self._w_bar / 2 return None # Apply bisection to left edge if self._set_l_bisection: if self._current_log_y >= fx: self._l_hat = (self._l_hat + self._w_bar) self._set_l_bisection = False return None # Apply bisection to right edge if self._set_r_bisection: if self._current_log_y >= fx: self._r_hat = (self._r_hat - self._w_bar) self._set_r_bisection = False # Reset flag for next bisection iteration self._bisection = True # Decrease count of bisection steps left self._a_bar -= 1 return None # If trial point is not acceptable, maintain current state if (self._proposed[self._active_param_index] < self._l_bar or self._proposed[self._active_param_index] > self._r_bar or self._current_log_y >= fx): # Reset proposal to undo last change self._proposed[self._active_param_index] = ( self._current[self._active_param_index]) # And update fx to the corresponding log pdf (needed below!) fx = self._temporary_log_pdf # Reset flags for next interval expansion self._first_expansion = True self._interval_found = False # Reset overrelaxation flags self._init_overrelaxation = True self._bisection = True # Reset active parameter indices if self._active_param_index == len(self._proposed) - 1: self._active_param_index = 0 # The accepted sample becomes the new current sample self._current = np.copy(self._proposed) # The log_pdf of the accepted sample is used to construct the # new slice self._temporary_log_pdf = fx # Sample new log_y used to define the next slice e = np.random.exponential(1) self._current_log_y = fx - e # Check whether next mcmc step should be overrelaxed self._overrelaxed_step = (np.random.uniform() < self._prob_overrelaxed) return np.copy(self._current), fx, True else: self._temporary_log_pdf = fx self._active_param_index += 1 return None # Normal Stepout step else: # Do ``Threshold Check`` to check if the proposed point is within # the slice if self._current_log_y < fx: self._first_expansion = True self._interval_found = False # Reset active parameter indices if self._active_param_index == len(self._proposed) - 1: self._active_param_index = 0 # The accepted sample becomes the new current sample self._current = np.copy(self._proposed) # The log_pdf of the accepted sample is used to construct # the new slice self._temporary_log_pdf = fx # Sample new log_y used to define the next slice e = np.random.exponential(1) self._current_log_y = fx - e # Check whether next mcmc step should be overrelaxed self._overrelaxed_step = (np.random.uniform() < self._prob_overrelaxed) if self._overrelaxed_step: self._init_overrelaxation = True self._bisection = True return np.copy(self._current), fx, True else: self._temporary_log_pdf = fx self._active_param_index += 1 return None # If the trial point is rejected in the ``Threshold Check``, shrink # the interval if (self._proposed[self._active_param_index] < self._current[self._active_param_index]): self._l = self._proposed[self._active_param_index] self._temp_l[self._active_param_index] = self._l else: self._r = self._proposed[self._active_param_index] self._temp_r[self._active_param_index] = self._r return None
[docs] def width(self): """ Returns the width used for generating the interval. """ return np.copy(self._w)