Source code for pints.plot._histogram

#
# Plot a single histogram
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import numpy as np
import scipy.stats as stats


def histogram(
        samples, kde=False, n_percentiles=None, parameter_names=None,
        ref_parameters=None):
    """
    Takes one or more markov chains or lists of samples as input and creates
    and returns a plot showing histograms for each chain or list of samples.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    samples
        A list of lists of samples, with shape
        ``(n_lists, n_samples, n_parameters)``, where ``n_lists`` is the
        number of lists of samples, ``n_samples`` is the number of samples in
        one list and ``n_parameters`` is the number of parameters. The lists
        may contain different numbers of samples, but must all have the same
        number of parameters.
    kde
        Set to ``True`` to include kernel-density estimation for the
        histograms.
    n_percentiles
        Shows only the middle n-th percentiles of the distribution.
        Default shows all samples in ``samples``.
    parameter_names
        A list of parameter names, which will be displayed on the x-axis of
        the histogram subplots. If no names are provided, the parameters are
        enumerated.
    ref_parameters
        A set of parameters for reference in the plot. For example, if true
        values of parameters are known, they can be passed in for plotting.

    Raises
    ------
    ValueError
        If ``samples`` is empty, if any list of samples is not 2-dimensional,
        if the lists disagree on the number of parameters, or if
        ``parameter_names`` or ``ref_parameters`` do not have one entry per
        parameter.
    """
    # If we switch to Python3 exclusively, bins and alpha can be keyword-only
    # arguments
    bins = 40
    alpha = 0.5

    # Convert every list of samples separately instead of building one 3d
    # array: lists may differ in length, and numpy refuses to create ragged
    # arrays.
    samples = [np.asarray(samples_j) for samples_j in samples]
    n_list = len(samples)
    if n_list == 0:
        raise ValueError('At least one list of samples must be given.')

    # Check shape of samples
    for samples_j in samples:
        if samples_j.ndim != 2:
            raise ValueError(
                'Each list of samples must have shape'
                ' `(n_samples, n_parameters)`.')
    n_param = samples[0].shape[1]
    for samples_j in samples:
        if samples_j.shape[1] != n_param:
            raise ValueError(
                'All samples must have the same number of parameters.')

    # Check parameter names
    if parameter_names is None:
        parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
    elif len(parameter_names) != n_param:
        raise ValueError(
            'Length of `parameter_names` must be same as number of'
            ' parameters.')

    # Check reference parameters
    if ref_parameters is not None:
        if len(ref_parameters) != n_param:
            raise ValueError(
                'Length of `ref_parameters` must be same as number of'
                ' parameters.')

    # Imported only after the arguments are validated, so that invalid input
    # is reported without loading the plotting backend.
    import matplotlib.pyplot as plt

    # Set up figure
    fig, axes = plt.subplots(
        n_param, 1, figsize=(6, 2 * n_param),
        squeeze=False,    # Tell matplotlib to always return a 2d axes object
    )

    # Find ranges across all samples, so that every list of samples shares
    # the same bin edges for a given parameter
    stacked_chains = np.vstack(samples)
    if n_percentiles is None:
        xmin = np.min(stacked_chains, axis=0)
        xmax = np.max(stacked_chains, axis=0)
    else:
        xmin = np.percentile(
            stacked_chains, 50 - n_percentiles / 2., axis=0)
        xmax = np.percentile(
            stacked_chains, 50 + n_percentiles / 2., axis=0)

    # One column of bin edges per parameter: shape (bins, n_param)
    xbins = np.linspace(xmin, xmax, bins)

    # Plot one subplot per parameter
    for i in range(n_param):
        ax = axes[i, 0]
        ax.set_xlabel(parameter_names[i])
        ax.set_ylabel('Frequency')

        for j_list, samples_j in enumerate(samples):
            # Add histogram subplot
            ax.hist(
                samples_j[:, i], bins=xbins[:, i], alpha=alpha,
                density=True, label='Samples ' + str(1 + j_list))

            # Add kde plot
            if kde:
                x = np.linspace(xmin[i], xmax[i], 100)
                ax.plot(x, stats.gaussian_kde(samples_j[:, i])(x))

        # Add reference parameters if given
        if ref_parameters is not None:
            # Dashed vertical line from zero up to the current top of the
            # y-axis
            _, ymax_tv = ax.get_ylim()
            ax.plot(
                [ref_parameters[i], ref_parameters[i]],
                [0.0, ymax_tv],
                '--', c='k')

    # A legend is only useful when there is more than one list of samples
    if n_list > 1:
        axes[0, 0].legend()

    plt.tight_layout()
    return fig, axes[:, 0]