#
# Plot a single histogram
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import numpy as np
import scipy.stats as stats
def histogram(
        samples,
        kde=False,
        n_percentiles=None,
        parameter_names=None,
        ref_parameters=None):
    """
    Takes one or more markov chains or lists of samples as input and creates
    and returns a plot showing histograms for each chain or list of samples.

    Returns a ``matplotlib`` figure object and a 1d array of axes handles
    (one subplot per parameter).

    Parameters
    ----------
    samples
        A list of lists of samples, with shape
        ``(n_lists, n_samples, n_parameters)``, where ``n_lists`` is the
        number of lists of samples, ``n_samples`` is the number of samples in
        one list and ``n_parameters`` is the number of parameters. The lists
        may contain different numbers of samples, but must all have the same
        number of parameters.
    kde
        Set to ``True`` to include kernel-density estimation for the
        histograms.
    n_percentiles
        Shows only the middle n-th percentiles of the distribution.
        Default shows all samples in ``samples``.
    parameter_names
        A list of parameter names, which will be displayed on the x-axis of the
        histogram subplots. If no names are provided, the parameters are
        enumerated.
    ref_parameters
        A set of parameters for reference in the plot. For example, if true
        values of parameters are known, they can be passed in for plotting.

    Raises
    ------
    ValueError
        If ``samples`` is empty, if any list of samples is not 2d, if the
        lists disagree on the number of parameters, or if
        ``parameter_names`` or ``ref_parameters`` has the wrong length.
    """
    # If we switch to Python3 exclusively, bins and alpha can be keyword-only
    # arguments
    bins = 40
    alpha = 0.5

    # Convert each list separately (instead of ``np.array(samples)``) so that
    # lists holding different numbers of samples are accepted too
    samples = [np.asarray(chain) for chain in samples]
    n_list = len(samples)
    if n_list == 0:
        raise ValueError('At least one list of samples must be given.')
    if any(chain.ndim != 2 for chain in samples):
        raise ValueError(
            'Each list of samples must have shape (n_samples, n_parameters).')
    n_param = samples[0].shape[1]
    if any(chain.shape[1] != n_param for chain in samples):
        raise ValueError(
            'All lists of samples must have the same number of parameters.')

    # Check parameter names
    if parameter_names is None:
        parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
    elif len(parameter_names) != n_param:
        raise ValueError(
            'Length of `parameter_names` must be same as number of'
            ' parameters.')

    # Check reference parameters
    if ref_parameters is not None:
        if len(ref_parameters) != n_param:
            raise ValueError(
                'Length of `ref_parameters` must be same as number of'
                ' parameters.')

    # Deferred import, done after validation so invalid input raises before
    # matplotlib is loaded
    import matplotlib.pyplot as plt

    # Set up figure
    fig, axes = plt.subplots(
        n_param, 1, figsize=(6, 2 * n_param),
        squeeze=False,  # Tell matplotlib to always return a 2d axes object
    )

    # Find ranges across all samples
    stacked_chains = np.vstack(samples)
    if n_percentiles is None:
        xmin = np.min(stacked_chains, axis=0)
        xmax = np.max(stacked_chains, axis=0)
    else:
        xmin = np.percentile(stacked_chains, 50 - n_percentiles / 2., axis=0)
        xmax = np.percentile(stacked_chains, 50 + n_percentiles / 2., axis=0)
    # Shape (bins, n_param): column i holds the bin edges for parameter i
    xbins = np.linspace(xmin, xmax, bins)

    for i, ax in enumerate(axes[:, 0]):
        ax.set_xlabel(parameter_names[i])
        # density=True below normalises every histogram, so the y-axis is a
        # probability density rather than a raw count
        ax.set_ylabel('Probability density')

        # Evaluation grid for the kde curves of this parameter
        x = np.linspace(xmin[i], xmax[i], 100) if kde else None
        for j_list, samples_j in enumerate(samples):
            ax.hist(
                samples_j[:, i], bins=xbins[:, i], alpha=alpha,
                density=True, label='Samples ' + str(1 + j_list))
            if kde:
                ax.plot(x, stats.gaussian_kde(samples_j[:, i])(x))

        # Draw the reference once per subplot, after all histograms, so it
        # spans the final y-range and is not repeated for every list
        if ref_parameters is not None:
            _, ymax = ax.get_ylim()
            ax.plot(
                [ref_parameters[i], ref_parameters[i]],
                [0.0, ymax],
                '--', c='k')

    if n_list > 1:
        axes[0, 0].legend()
    plt.tight_layout()
    return fig, axes[:, 0]