Source code for pints.plot._histogram

#
# Plot a single histogram
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
from distutils.version import LooseVersion

import numpy as np
import scipy.stats as stats


def histogram(
        samples, kde=False, n_percentiles=None, parameter_names=None,
        ref_parameters=None):
    """
    Takes one or more markov chains or lists of samples as input and creates
    and returns a plot showing histograms for each chain or list of samples.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    samples
        A list of lists of samples, with shape
        ``(n_lists, n_samples, n_parameters)``, where ``n_lists`` is the
        number of lists of samples, ``n_samples`` is the number of samples in
        one list and ``n_parameters`` is the number of parameters. The lists
        are converted one by one, so they may contain different numbers of
        samples.
    kde
        Set to ``True`` to include kernel-density estimation for the
        histograms.
    n_percentiles
        Shows only the middle n-th percentiles of the distribution.
        Default shows all samples in ``samples``.
    parameter_names
        A list of parameter names, which will be displayed on the x-axis of
        the histogram subplots. If no names are provided, the parameters are
        enumerated.
    ref_parameters
        A set of parameters for reference in the plot. For example, if true
        values of parameters are known, they can be passed in for plotting.

    Raises
    ------
    ValueError
        If ``samples`` is empty, or if the length of ``parameter_names`` or
        ``ref_parameters`` differs from the number of parameters.
    """
    import re

    # If we switch to Python3 exclusively, bins and alpha can be keyword-only
    # arguments
    bins = 40
    alpha = 0.5

    # Convert every list of samples on its own: they are only ever stacked
    # along the sample axis, so they need not fit one rectangular array.
    samples = [np.asarray(chain) for chain in samples]
    n_list = len(samples)
    if n_list == 0:
        raise ValueError('At least one list of samples must be provided.')
    _, n_param = samples[0].shape

    # Check parameter names
    if parameter_names is None:
        parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
    elif len(parameter_names) != n_param:
        raise ValueError(
            'Length of `parameter_names` must be same as number of'
            ' parameters.')

    # Check reference parameters
    if ref_parameters is not None:
        if len(ref_parameters) != n_param:
            raise ValueError(
                'Length of `ref_parameters` must be same as number of'
                ' parameters.')

    # Only load matplotlib once the arguments are known to be valid, so that
    # invalid input is reported before any plotting machinery is touched.
    import matplotlib
    import matplotlib.pyplot as plt

    # Check matplotlib version: releases before 2.2 are given ``normed``,
    # later ones ``density``. The leading "major.minor" is compared
    # numerically; an unparseable version string is treated as recent.
    version_match = re.match(r'(\d+)\.(\d+)', matplotlib.__version__)
    use_old_matplotlib = (
        version_match is not None
        and tuple(int(part) for part in version_match.groups()) < (2, 2))
    norm_keyword = 'normed' if use_old_matplotlib else 'density'

    # Set up figure
    fig, axes = plt.subplots(
        n_param, 1, figsize=(6, 2 * n_param),
        squeeze=False,    # Tell matplotlib to always return a 2d axes object
    )

    # Find ranges across all samples
    stacked_chains = np.vstack(samples)
    if n_percentiles is None:
        xmin = np.min(stacked_chains, axis=0)
        xmax = np.max(stacked_chains, axis=0)
    else:
        xmin = np.percentile(
            stacked_chains, 50 - n_percentiles / 2., axis=0)
        xmax = np.percentile(
            stacked_chains, 50 + n_percentiles / 2., axis=0)

    # One column of bin edges per parameter, shared by all lists of samples
    xbins = np.linspace(xmin, xmax, bins)

    # Plot one subplot per parameter
    for i in range(n_param):
        ax = axes[i, 0]
        ax.set_xlabel(parameter_names[i])
        ax.set_ylabel('Frequency')

        # Evaluation points for the kde curves of this parameter
        if kde:
            x = np.linspace(xmin[i], xmax[i], 100)

        for j_list, samples_j in enumerate(samples):
            # Add histogram
            ax.hist(
                samples_j[:, i], bins=xbins[:, i], alpha=alpha,
                label='Samples ' + str(1 + j_list),
                **{norm_keyword: True})

            # Add kde plot
            if kde:
                ax.plot(x, stats.gaussian_kde(samples_j[:, i])(x))

        # Add reference parameters if given
        if ref_parameters is not None:
            # Dashed vertical line from zero up to the current top of the
            # y-axis
            _, ymax_tv = ax.get_ylim()
            ax.plot(
                [ref_parameters[i], ref_parameters[i]],
                [0.0, ymax_tv],
                '--', c='k')

    # A legend is only useful when there is more than one list of samples
    if n_list > 1:
        axes[0, 0].legend()

    plt.tight_layout()
    return fig, axes[:, 0]