#
# Plot a single histogram
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
from distutils.version import LooseVersion
import numpy as np
import scipy.stats as stats
# [docs]
def _histogram_chains(samples):
    """
    Converts ``samples`` into a list of 2d numpy arrays, one per chain.

    Every chain is converted on its own (instead of coercing ``samples`` into
    a single array), so chains with different numbers of samples are accepted.

    Raises a ``ValueError`` if no chains are given, if a chain is not
    2-dimensional, or if the chains disagree on the number of parameters.
    """
    chains = [np.asarray(chain) for chain in samples]
    if not chains:
        raise ValueError('At least one list of samples must be given.')
    for chain in chains:
        if chain.ndim != 2:
            raise ValueError(
                'Each list of samples must have shape'
                ' `(n_samples, n_parameters)`.')
    n_param = chains[0].shape[1]
    if any(chain.shape[1] != n_param for chain in chains):
        raise ValueError(
            'All lists of samples must have the same number of parameters.')
    return chains


def _histogram_needs_normed(version):
    """
    Returns ``True`` if the matplotlib version string ``version`` is older
    than 2.2, in which case ``hist`` is called with ``normed`` instead of
    ``density``.

    Only the leading ``major.minor`` numbers are compared, which avoids
    relying on ``distutils`` (removed from the standard library in Python
    3.12).
    """
    import re

    match = re.match(r'(\d+)(?:\.(\d+))?', version)
    if match is None:
        # Unrecognised version string: assume a recent matplotlib
        return False
    major = int(match.group(1))
    minor = int(match.group(2) or 0)
    return (major, minor) < (2, 2)


def histogram(
        samples,
        kde=False,
        n_percentiles=None,
        parameter_names=None,
        ref_parameters=None):
    """
    Takes one or more markov chains or lists of samples as input and creates
    and returns a plot showing histograms for each chain or list of samples.

    One subplot is created per parameter; the histograms of all chains are
    overlaid in that subplot, using bins shared by all chains.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    samples
        A list of lists of samples, with shape
        ``(n_lists, n_samples, n_parameters)``, where ``n_lists`` is the
        number of lists of samples, ``n_samples`` is the number of samples in
        one list and ``n_parameters`` is the number of parameters. The lists
        may contain different numbers of samples.
    kde
        Set to ``True`` to include kernel-density estimation for the
        histograms.
    n_percentiles
        If given, the x-range of each histogram is restricted to the central
        ``n_percentiles`` percent of the pooled samples (from the
        ``50 - n_percentiles / 2`` to the ``50 + n_percentiles / 2``
        percentile). Default shows all samples in ``samples``.
    parameter_names
        A list of parameter names, which will be displayed on the x-axis of the
        histogram subplots. If no names are provided, the parameters are
        enumerated.
    ref_parameters
        A set of parameters for reference in the plot. For example, if true
        values of parameters are known, they can be passed in for plotting.

    Returns
    -------
    fig
        The ``matplotlib`` figure.
    axes
        A 1d array containing one axes object per parameter.

    Raises
    ------
    ValueError
        If ``samples`` is empty or its lists are not 2-dimensional with a
        common number of parameters, or if ``parameter_names`` or
        ``ref_parameters`` does not have one entry per parameter.
    """
    # If we switch to Python3 exclusively, bins and alpha can be keyword-only
    # arguments
    bins = 40
    alpha = 0.5

    # Validate all arguments before importing matplotlib, so that invalid
    # input is reported immediately
    chains = _histogram_chains(samples)
    n_list = len(chains)
    n_param = chains[0].shape[1]

    # Check parameter names
    if parameter_names is None:
        parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
    elif len(parameter_names) != n_param:
        raise ValueError(
            'Length of `parameter_names` must be same as number of'
            ' parameters.')

    # Check reference parameters
    if ref_parameters is not None:
        if len(ref_parameters) != n_param:
            raise ValueError(
                'Length of `ref_parameters` must be same as number of'
                ' parameters.')

    import matplotlib
    import matplotlib.pyplot as plt

    # Check matplotlib version: old versions normalise with `normed`
    if _histogram_needs_normed(matplotlib.__version__):  # pragma: no cover
        hist_norm = {'normed': True}
    else:
        hist_norm = {'density': True}

    # Set up figure
    fig, axes = plt.subplots(
        n_param, 1, figsize=(6, 2 * n_param),
        squeeze=False,    # Tell matplotlib to always return a 2d axes object
    )

    # Find ranges across all samples
    stacked_chains = np.vstack(chains)
    if n_percentiles is None:
        xmin = np.min(stacked_chains, axis=0)
        xmax = np.max(stacked_chains, axis=0)
    else:
        xmin = np.percentile(
            stacked_chains, 50 - n_percentiles / 2., axis=0)
        xmax = np.percentile(
            stacked_chains, 50 + n_percentiles / 2., axis=0)

    # Bin edges, with one column per parameter, shared by all chains
    xbins = np.linspace(xmin, xmax, bins)

    for i in range(n_param):
        ax = axes[i, 0]
        ax.set_xlabel(parameter_names[i])
        ax.set_ylabel('Frequency')

        # The kde is evaluated on the same grid for every chain
        if kde:
            x = np.linspace(xmin[i], xmax[i], 100)

        for j_list, chain in enumerate(chains):
            # Add histogram
            ax.hist(
                chain[:, i], bins=xbins[:, i], alpha=alpha,
                label='Samples ' + str(1 + j_list), **hist_norm)

            # Add kde plot
            if kde:
                ax.plot(x, stats.gaussian_kde(chain[:, i])(x))

        # Add reference parameters if given. This is done once per parameter,
        # after all chains are drawn, so the line spans the final y-range.
        if ref_parameters is not None:
            _, ymax_tv = ax.get_ylim()
            ax.plot(
                [ref_parameters[i], ref_parameters[i]],
                [0.0, ymax_tv],
                '--', c='k')

    if n_list > 1:
        axes[0, 0].legend()

    plt.tight_layout()
    return fig, axes[:, 0]