Source code for pints.plot._pairwise

#
# Plots pairwise scatterplots for all parameter pairs
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import warnings

import numpy as np
import scipy.stats as stats


def pairwise(samples,
             kde=False,
             heatmap=False,
             opacity=None,
             n_percentiles=None,
             parameter_names=None,
             ref_parameters=None):
    """
    Takes a markov chain or list of ``samples`` and creates a set of pairwise
    scatterplots for all parameters (p1 versus p2, p1 versus p3, p2 versus p3,
    etc.).

    The returned plot is in a 'matrix' form, with histograms of each
    individual parameter on the diagonal, and scatter plots of parameters
    ``i`` and ``j`` on each entry ``(i, j)`` below the diagonal.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    samples
        A list of samples, with shape ``(n_samples, n_parameters)``, where
        ``n_samples`` is the number of samples in the list and
        ``n_parameters`` is the number of parameters. Anything that
        ``numpy.asarray`` converts to a 2-dimensional array is accepted.
    kde
        Set to ``True`` to use kernel-density estimation for the
        histograms and scatter plots. Cannot use together with ``heatmap``.
    heatmap
        Set to ``True`` to plot heatmap for the pairwise plots.
        Cannot be used together with ``kde``.
    opacity
        This value can be used to manually set the opacity of the
        points in the scatter plots (when ``kde=False`` and ``heatmap=False``
        only). If not given, points are drawn fully opaque for fewer than 10
        samples and with opacity ``1 / log10(n_samples)`` otherwise.
    n_percentiles
        Shows only the middle n-th percentiles of the distribution.
        Default shows all samples in ``samples``.
    parameter_names
        A list of parameter names, which will be displayed on the axes of the
        subplots. If no names are provided, the parameters are enumerated.
    ref_parameters
        A set of parameters for reference in the plot. For example, if true
        values of parameters are known, they can be passed in for plotting.

    Raises
    ------
    ValueError
        If ``kde`` and ``heatmap`` are both set, if ``samples`` is not
        2-dimensional, if there are fewer than 2 parameters, or if
        ``parameter_names`` or ``ref_parameters`` do not have one entry per
        parameter.
    """
    # All argument checks are done before matplotlib is imported and before a
    # figure is created, so that invalid calls fail fast and leave no
    # half-built figure behind.

    # Check options kde and heatmap
    if kde and heatmap:
        raise ValueError('Cannot use `kde` and `heatmap` together.')

    # Check samples size. Converting first means plain (nested) lists can be
    # used as well as arrays: the column slicing below needs an array.
    samples = np.asarray(samples)
    if samples.ndim != 2:
        raise ValueError(
            '`samples` must be of shape (n_sample, n_parameters).')
    n_sample, n_param = samples.shape

    # Check number of parameters
    if n_param < 2:
        raise ValueError('Number of parameters must be at least 2.')

    # Check parameter names
    if parameter_names is None:
        parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
    elif len(parameter_names) != n_param:
        raise ValueError(
            'Length of `parameter_names` must be same as number of'
            ' parameters.')

    # Check reference parameters
    if ref_parameters is not None:
        if len(ref_parameters) != n_param:
            raise ValueError(
                'Length of `ref_parameters` must be same as number of'
                ' parameters.')

    # Determine the plotted range of every parameter once, instead of
    # recomputing it for each subplot
    if n_percentiles is None:
        lower = np.min(samples, axis=0)
        upper = np.max(samples, axis=0)
    else:
        lower = np.percentile(samples, 50 - n_percentiles / 2., axis=0)
        upper = np.percentile(samples, 50 + n_percentiles / 2., axis=0)

    # Determine point opacity (only used by the plain scatter plots)
    if not kde and not heatmap and opacity is None:
        if n_sample < 10:
            opacity = 1.0
        else:
            opacity = 1.0 / np.log10(n_sample)

    import matplotlib.pyplot as plt

    # Create figure
    fig_size = (3 * n_param, 3 * n_param)
    fig, axes = plt.subplots(n_param, n_param, figsize=fig_size)

    bins = 25
    for i in range(n_param):
        for j in range(n_param):
            ax = axes[i, j]

            if i == j:
                # Diagonal: Plot a histogram
                xmin, xmax = lower[i], upper[i]
                xbins = np.linspace(xmin, xmax, bins)
                ax.set_xlim(xmin, xmax)
                ax.hist(samples[:, i], bins=xbins, density=True)

                # Add kde plot
                if kde:
                    x = np.linspace(xmin, xmax, 100)
                    ax.plot(x, stats.gaussian_kde(samples[:, i])(x))

                # Add reference parameters if given
                if ref_parameters is not None:
                    _, ymax_tv = ax.get_ylim()
                    ax.plot(
                        [ref_parameters[i], ref_parameters[i]],
                        [0.0, ymax_tv],
                        '--', c='k')

            elif i < j:
                # Top-right: no plot
                ax.axis('off')

            else:
                # Lower-left: Plot the samples as density map
                xmin, xmax = lower[j], upper[j]
                ymin, ymax = lower[i], upper[i]
                ax.set_xlim(xmin, xmax)
                ax.set_ylim(ymin, ymax)

                if kde:
                    # Create a KDE-based plot

                    # Plot values
                    # NOTE(review): this draws the raw sample array as an
                    # image; it looks like it is painted over by the filled
                    # contours below -- confirm before removing.
                    values = np.vstack([samples[:, j], samples[:, i]])
                    ax.imshow(
                        np.rot90(values), cmap=plt.cm.Blues,
                        extent=[xmin, xmax, ymin, ymax])

                    # Create grid
                    xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
                    positions = np.vstack([xx.ravel(), yy.ravel()])

                    # Get kernel density estimate and plot contours
                    kernel = stats.gaussian_kde(values)
                    f = np.reshape(kernel(positions).T, xx.shape)
                    ax.contourf(xx, yy, f, cmap='Blues')
                    ax.contour(xx, yy, f, colors='k')

                elif heatmap:
                    # Create a heatmap-based plot

                    # Create bins
                    xbins = np.linspace(xmin, xmax, bins)
                    ybins = np.linspace(ymin, ymax, bins)

                    # Plot heatmap
                    ax.hist2d(
                        samples[:, j], samples[:, i],
                        bins=(xbins, ybins), cmap=plt.cm.Blues)

                else:
                    # Create scatter plot
                    ax.scatter(
                        samples[:, j], samples[:, i], alpha=opacity, s=0.1)

                if kde or heatmap:
                    # Force equal aspect ratio
                    # Matplotlib raises a warning here (on 2.7 at least)
                    # We can't do anything about it, so no other option than
                    # to suppress it at this stage...
                    with warnings.catch_warnings():
                        warnings.simplefilter('ignore', UnicodeWarning)
                        ax.set_aspect((xmax - xmin) / (ymax - ymin))

                # Add reference parameters if given
                if ref_parameters is not None:
                    ax.plot(
                        [ref_parameters[j], ref_parameters[j]],
                        [ymin, ymax],
                        '--', c='k')
                    ax.plot(
                        [xmin, xmax],
                        [ref_parameters[i], ref_parameters[i]],
                        '--', c='k')

            # Set tick labels
            if i < n_param - 1:
                # Only show x tick labels for the last row
                ax.set_xticklabels([])
            else:
                # Rotate the x tick labels to fit in the plot
                for tl in ax.get_xticklabels():
                    tl.set_rotation(45)

            if j > 0:
                # Only show y tick labels for the first column
                ax.set_yticklabels([])

        # Set axis labels
        axes[-1, i].set_xlabel(parameter_names[i])
        if i == 0:
            # The first one is not a parameter
            axes[i, 0].set_ylabel('Frequency')
        else:
            axes[i, 0].set_ylabel(parameter_names[i])

    return fig, axes