#
# Plot traces for a selection of samples from a chain
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import numpy as np
[docs]
def trace(
samples,
n_percentiles=None,
parameter_names=None,
ref_parameters=None):
"""
Takes one or more markov chains or lists of samples as input and creates
and returns a plot showing histograms and traces for each chain or list of
samples.
Returns a ``matplotlib`` figure object and axes handle.
Parameters
----------
samples
A list of lists of samples, with shape
``(n_lists, n_samples, n_parameters)``, where ``n_lists`` is the
number of lists of samples, ``n_samples`` is the number of samples in
one list and ``n_parameters`` is the number of parameters.
n_percentiles
Shows only the middle n-th percentiles of the distribution.
Default shows all samples in ``samples``.
parameter_names
A list of parameter names, which will be displayed on the x-axis of the
trace subplots. If no names are provided, the parameters are
enumerated.
ref_parameters
A set of parameters for reference in the plot. For example, if true
values of parameters are known, they can be passed in for plotting.
"""
import matplotlib.pyplot as plt
# If we switch to Python3 exclusively, bins and alpha can be keyword-only
# arguments
bins = 40
alpha = 0.5
n_list = len(samples)
_, n_param = samples[0].shape
# Check number of parameters
for samples_j in samples:
if n_param != samples_j.shape[1]:
raise ValueError(
'All samples must have the same number of parameters.'
)
# Check parameter names
if parameter_names is None:
parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
elif len(parameter_names) != n_param:
raise ValueError(
'Length of `parameter_names` must be same as number of'
' parameters.')
# Check reference parameters
if ref_parameters is not None:
if len(ref_parameters) != n_param:
raise ValueError(
'Length of `ref_parameters` must be same as number of'
' parameters.')
# Set up figure
fig, axes = plt.subplots(
n_param, 2, figsize=(12, 2 * n_param),
# Tell matplotlib to return 2d, even if n_param is 1
squeeze=False,
)
# Find ranges across all samples
stacked_chains = np.vstack(samples)
if n_percentiles is None:
xmin = np.min(stacked_chains, axis=0)
xmax = np.max(stacked_chains, axis=0)
else:
xmin = np.percentile(stacked_chains,
50 - n_percentiles / 2.,
axis=0)
xmax = np.percentile(stacked_chains,
50 + n_percentiles / 2.,
axis=0)
xbins = np.linspace(xmin, xmax, bins)
# Plot first samples
for i in range(n_param):
ymin_all, ymax_all = np.inf, -np.inf
for j_list, samples_j in enumerate(samples):
# Add histogram subplot
axes[i, 0].set_xlabel(parameter_names[i])
axes[i, 0].set_ylabel('Frequency')
axes[i, 0].hist(samples_j[:, i], bins=xbins[:, i], alpha=alpha,
label='Samples ' + str(1 + j_list))
# Add trace subplot
axes[i, 1].set_xlabel('Iteration')
axes[i, 1].set_ylabel(parameter_names[i])
axes[i, 1].plot(samples_j[:, i], alpha=alpha)
# Set ylim
ymin_all = min(ymin_all, xmin[i])
ymax_all = max(ymax_all, xmax[i])
axes[i, 1].set_ylim([ymin_all, ymax_all])
# Add reference parameters if given
if ref_parameters is not None:
# For histogram subplot
ymin_tv, ymax_tv = axes[i, 0].get_ylim()
axes[i, 0].plot(
[ref_parameters[i], ref_parameters[i]],
[0.0, ymax_tv],
'--', c='k')
# For trace subplot
xmin_tv, xmax_tv = axes[i, 1].get_xlim()
axes[i, 1].plot(
[0.0, xmax_tv],
[ref_parameters[i], ref_parameters[i]],
'--', c='k')
if n_list > 1:
axes[0, 0].legend()
plt.tight_layout()
return fig, axes