#
# Plots pairwise scatterplots for all parameter pairs
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import warnings
from distutils.version import LooseVersion
import numpy as np
import scipy.stats as stats
def pairwise(samples,
             kde=False,
             heatmap=False,
             opacity=None,
             n_percentiles=None,
             parameter_names=None,
             ref_parameters=None):
    """
    Takes a markov chain or list of ``samples`` and creates a set of pairwise
    scatterplots for all parameters (p1 versus p2, p1 versus p3, p2 versus p3,
    etc.).

    The returned plot is in a 'matrix' form, with histograms of each individual
    parameter on the diagonal, and scatter plots of parameters ``i`` and ``j``
    on each entry ``(i, j)`` below the diagonal. Entries above the diagonal
    are left empty.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    samples
        An array-like (e.g. a numpy array or a list of lists) of samples, with
        shape ``(n_samples, n_parameters)``, where ``n_samples`` is the number
        of samples in the list and ``n_parameters`` is the number of
        parameters.
    kde
        Set to ``True`` to use kernel-density estimation for the
        histograms and scatter plots. Cannot use together with ``heatmap``.
    heatmap
        Set to ``True`` to plot heatmap for the pairwise plots.
        Cannot be used together with ``kde``.
    opacity
        This value can be used to manually set the opacity of the
        points in the scatter plots (when ``kde=False`` and ``heatmap=False``
        only). If not given, ``1`` is used for fewer than 10 samples and
        ``1 / log10(n_samples)`` otherwise.
    n_percentiles
        Shows only the middle n-th percentiles of the distribution.
        Default shows all samples in ``samples``.
    parameter_names
        A list of parameter names, which will be displayed on the axes of the
        subplots. If no names are provided, the parameters are enumerated.
    ref_parameters
        A set of parameters for reference in the plot. For example,
        if true values of parameters are known, they can be passed in for
        plotting.

    Raises
    ------
    ValueError
        If ``kde`` and ``heatmap`` are both set, if ``samples`` is not
        two-dimensional, if there are fewer than two parameters, or if
        ``parameter_names`` or ``ref_parameters`` do not have one entry per
        parameter.
    """
    # All arguments are validated before matplotlib is imported, so that an
    # invalid call fails immediately with a clear error.

    # Check options kde and heatmap
    if kde and heatmap:
        raise ValueError('Cannot use `kde` and `heatmap` together.')

    # Check samples size. Converting first means that plain (nested) lists
    # are accepted as well as numpy arrays.
    try:
        samples = np.asarray(samples)
        n_sample, n_param = samples.shape
    except ValueError:
        raise ValueError('`samples` must be of shape (n_sample,'
                         + ' n_parameters).')

    # Check number of parameters
    if n_param < 2:
        raise ValueError('Number of parameters must be at least 2.')

    # Check parameter names
    if parameter_names is None:
        parameter_names = ['Parameter' + str(i + 1) for i in range(n_param)]
    elif len(parameter_names) != n_param:
        raise ValueError(
            'Length of `parameter_names` must be same as number of'
            ' parameters.')

    # Check reference parameters
    if ref_parameters is not None:
        if len(ref_parameters) != n_param:
            raise ValueError(
                'Length of `ref_parameters` must be same as number of'
                ' parameters.')

    import matplotlib
    import matplotlib.pyplot as plt

    # Matplotlib versions before 2.2 are asked for a normalised histogram via
    # ``normed`` instead of ``density``. The version string is parsed locally,
    # so that this function does not rely on ``distutils`` (which was removed
    # from the standard library in Python 3.12).
    use_old_matplotlib = _matplotlib_older_than_2_2(matplotlib.__version__)

    # Scatter point opacity: the same for every subplot, so determined once
    if opacity is None:
        if n_sample < 10:
            opacity = 1.0
        else:
            opacity = 1.0 / np.log10(n_sample)

    # Plotting range of every parameter: also the same for every subplot
    lower, upper = _parameter_limits(samples, n_percentiles)

    # Create figure
    fig_size = (3 * n_param, 3 * n_param)
    fig, axes = plt.subplots(n_param, n_param, figsize=fig_size)

    bins = 25
    for i in range(n_param):
        for j in range(n_param):
            ax = axes[i, j]

            if i == j:
                # Diagonal: Plot a histogram
                xmin, xmax = lower[i], upper[i]
                xbins = np.linspace(xmin, xmax, bins)
                ax.set_xlim(xmin, xmax)
                if use_old_matplotlib:  # pragma: no cover
                    ax.hist(samples[:, i], bins=xbins, normed=True)
                else:
                    ax.hist(samples[:, i], bins=xbins, density=True)

                # Add kde plot
                if kde:
                    x = np.linspace(xmin, xmax, 100)
                    ax.plot(x, stats.gaussian_kde(samples[:, i])(x))

                # Add reference parameters if given: a vertical dashed line
                # from zero up to the current top of the y-axis
                if ref_parameters is not None:
                    ymax_tv = ax.get_ylim()[1]
                    ax.plot(
                        [ref_parameters[i], ref_parameters[i]],
                        [0.0, ymax_tv],
                        '--', c='k')

            elif i < j:
                # Top-right: no plot
                ax.axis('off')

            else:
                # Lower-left: Plot the samples, with parameter j on the
                # x-axis and parameter i on the y-axis
                xmin, xmax = lower[j], upper[j]
                ymin, ymax = lower[i], upper[i]
                ax.set_xlim(xmin, xmax)
                ax.set_ylim(ymin, ymax)

                if kde:
                    # Create a KDE-based plot

                    # Plot values
                    # NOTE(review): this draws the raw (rotated) sample array
                    # as an image over the plotting extent; it looks like the
                    # filled contours below are drawn on top of it -- confirm
                    # whether this call is still needed.
                    values = np.vstack([samples[:, j], samples[:, i]])
                    ax.imshow(
                        np.rot90(values), cmap=plt.cm.Blues,
                        extent=[xmin, xmax, ymin, ymax])

                    # Create grid
                    xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
                    positions = np.vstack([xx.ravel(), yy.ravel()])

                    # Get kernel density estimate and plot contours
                    kernel = stats.gaussian_kde(values)
                    f = np.reshape(kernel(positions).T, xx.shape)
                    ax.contourf(xx, yy, f, cmap='Blues')
                    ax.contour(xx, yy, f, colors='k')

                    # Force equal aspect ratio
                    _set_square_aspect(ax, xmin, xmax, ymin, ymax)

                elif heatmap:
                    # Create a heatmap-based plot

                    # Create bins
                    xbins = np.linspace(xmin, xmax, bins)
                    ybins = np.linspace(ymin, ymax, bins)

                    # Plot heatmap
                    ax.hist2d(samples[:, j], samples[:, i],
                              bins=(xbins, ybins), cmap=plt.cm.Blues)

                    # Force equal aspect ratio
                    _set_square_aspect(ax, xmin, xmax, ymin, ymax)

                else:
                    # Create scatter plot
                    ax.scatter(
                        samples[:, j], samples[:, i], alpha=opacity, s=0.1)

                # Add reference parameters if given: dashed cross-hairs
                if ref_parameters is not None:
                    ax.plot(
                        [ref_parameters[j], ref_parameters[j]],
                        [ymin, ymax],
                        '--', c='k')
                    ax.plot(
                        [xmin, xmax],
                        [ref_parameters[i], ref_parameters[i]],
                        '--', c='k')

            # Set tick labels
            if i < n_param - 1:
                # Only show x tick labels for the last row
                ax.set_xticklabels([])
            else:
                # Rotate the x tick labels to fit in the plot
                for tl in ax.get_xticklabels():
                    tl.set_rotation(45)

            if j > 0:
                # Only show y tick labels for the first column
                ax.set_yticklabels([])

        # Set axis labels
        axes[-1, i].set_xlabel(parameter_names[i])
        if i == 0:
            # The first one is not a parameter
            axes[i, 0].set_ylabel('Frequency')
        else:
            axes[i, 0].set_ylabel(parameter_names[i])

    return fig, axes


def _matplotlib_older_than_2_2(version):
    """
    Returns ``True`` if the matplotlib ``version`` string (e.g. ``'2.1.0'``)
    denotes a release before 2.2.

    Only the leading digits of the major and minor components are compared;
    any suffix (such as ``rc1``) is ignored.
    """
    release = []
    for part in version.split('.')[:2]:
        digits = ''
        for char in part:
            if not char.isdigit():
                break
            digits += char
        release.append(int(digits) if digits else 0)
    return tuple(release) < (2, 2)


def _parameter_limits(samples, n_percentiles):
    """
    Returns a tuple ``(lower, upper)`` of arrays containing the lower and
    upper plotting limit of every parameter (column) in ``samples``.

    If ``n_percentiles`` is ``None`` the limits are the minimum and maximum of
    each column, otherwise they enclose its middle ``n_percentiles`` percent.
    """
    if n_percentiles is None:
        return np.min(samples, axis=0), np.max(samples, axis=0)
    lower = np.percentile(samples, 50 - n_percentiles / 2., axis=0)
    upper = np.percentile(samples, 50 + n_percentiles / 2., axis=0)
    return lower, upper


def _set_square_aspect(ax, xmin, xmax, ymin, ymax):
    """
    Sets the aspect ratio of ``ax`` to ``(xmax - xmin) / (ymax - ymin)``, so
    that the given data ranges span equal lengths on screen.
    """
    if not (xmax > xmin and ymax > ymin):
        # Degenerate range: the ratio would be zero, infinite or NaN, so the
        # aspect is left unchanged
        return

    # Matplotlib raises a warning here (on 2.7 at least)
    # We can't do anything about it, so no other option than
    # to suppress it at this stage...
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', UnicodeWarning)
        ax.set_aspect((xmax - xmin) / (ymax - ymin))