Source code for pints.plot._trace

#
# Plot traces for a selection of samples from a chain
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import numpy as np


def trace(
        samples, n_percentiles=None, parameter_names=None,
        ref_parameters=None):
    """
    Takes one or more markov chains or lists of samples as input and creates
    and returns a plot showing histograms and traces for each chain or list of
    samples.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    samples
        A list of lists of samples, with shape
        ``(n_lists, n_samples, n_parameters)``, where ``n_lists`` is the
        number of lists of samples, ``n_samples`` is the number of samples in
        one list and ``n_parameters`` is the number of parameters. Each entry
        is converted with ``numpy.asarray``, so both arrays and nested lists
        are accepted. The lists may differ in ``n_samples`` but must agree in
        ``n_parameters``.
    n_percentiles
        Shows only the middle n-th percentiles of the distribution: the
        histogram bins and the y-range of the trace subplots are restricted
        to the ``50 - n_percentiles / 2`` to ``50 + n_percentiles / 2``
        percentile range, computed over all lists combined. Default shows
        all samples in ``samples``.
    parameter_names
        A list of parameter names, which will be displayed on the x-axis of
        the histogram subplots and on the y-axis of the trace subplots. If no
        names are provided, the parameters are enumerated.
    ref_parameters
        A set of parameters for reference in the plot. For example, if true
        values of parameters are known, they can be passed in for plotting.
        They are drawn as dashed black lines in both subplots.

    Returns
    -------
    fig
        The ``matplotlib`` figure.
    axes
        A 2d array of axes with shape ``(n_parameters, 2)``: column 0 holds
        the histograms, column 1 the traces.

    Raises
    ------
    ValueError
        If the lists of samples disagree on the number of parameters, or if
        ``parameter_names`` or ``ref_parameters`` does not have one entry per
        parameter.
    """
    # Number of histogram bin edges, and transparency of the plotted artists.
    # If we switch to Python3 exclusively, bins and alpha can be keyword-only
    # arguments
    bins = 40
    alpha = 0.5

    # Accept nested lists as well as arrays: everything below relies on
    # ``.shape`` and 2d slicing
    samples = [np.asarray(samples_j) for samples_j in samples]
    n_list = len(samples)
    _, n_param = samples[0].shape

    # Check number of parameters
    for samples_j in samples:
        if n_param != samples_j.shape[1]:
            raise ValueError(
                'All samples must have the same number of parameters.'
            )

    # Check parameter names
    if parameter_names is None:
        parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
    elif len(parameter_names) != n_param:
        raise ValueError(
            'Length of `parameter_names` must be same as number of'
            ' parameters.')

    # Check reference parameters
    if ref_parameters is not None and len(ref_parameters) != n_param:
        raise ValueError(
            'Length of `ref_parameters` must be same as number of'
            ' parameters.')

    # Imported only once the input is known to be valid, so that bad
    # arguments fail fast without loading the plotting backend
    import matplotlib.pyplot as plt

    # Set up figure: one row per parameter, histogram left and trace right
    fig, axes = plt.subplots(
        n_param, 2, figsize=(12, 2 * n_param),
        # Tell matplotlib to return 2d, even if n_param is 1
        squeeze=False,
    )

    # Find ranges across all samples (one lower/upper bound per parameter)
    stacked_chains = np.vstack(samples)
    if n_percentiles is None:
        xmin = np.min(stacked_chains, axis=0)
        xmax = np.max(stacked_chains, axis=0)
    else:
        xmin = np.percentile(stacked_chains, 50 - n_percentiles / 2., axis=0)
        xmax = np.percentile(stacked_chains, 50 + n_percentiles / 2., axis=0)

    # Shared bin edges, so histograms of different lists are comparable.
    # Column ``i`` holds the edges for parameter ``i``
    xbins = np.linspace(xmin, xmax, bins)

    for i in range(n_param):
        ax_hist, ax_trace = axes[i]

        # Labels do not depend on the list of samples: set them once per row
        ax_hist.set_xlabel(parameter_names[i])
        ax_hist.set_ylabel('Frequency')
        ax_trace.set_xlabel('Iteration')
        ax_trace.set_ylabel(parameter_names[i])

        # Overlay every list of samples in the same pair of subplots
        for j_list, samples_j in enumerate(samples):
            ax_hist.hist(
                samples_j[:, i], bins=xbins[:, i], alpha=alpha,
                label='Samples ' + str(1 + j_list))
            ax_trace.plot(samples_j[:, i], alpha=alpha)

        # Clip the trace to the same range as the histogram bins
        ax_trace.set_ylim([xmin[i], xmax[i]])

        # Add reference parameters if given
        if ref_parameters is not None:
            # For histogram subplot: vertical line up to the current top
            _, ymax_tv = ax_hist.get_ylim()
            ax_hist.plot(
                [ref_parameters[i], ref_parameters[i]],
                [0.0, ymax_tv],
                '--', c='k')

            # For trace subplot: horizontal line up to the current right edge
            _, xmax_tv = ax_trace.get_xlim()
            ax_trace.plot(
                [0.0, xmax_tv],
                [ref_parameters[i], ref_parameters[i]],
                '--', c='k')

    # A legend is only informative when several lists are overlaid
    if n_list > 1:
        axes[0, 0].legend()

    plt.tight_layout()
    return fig, axes