Source code for pints.plot._series

#
# Plots predicted posterior time series
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import numpy as np


def series(samples, problem, ref_parameters=None, thinning=None):
    """
    Creates and returns a plot of predicted time series for a given list of
    ``samples`` and a single-output or multi-output ``problem``.

    Because this method runs simulations, it can take a considerable time to
    run.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    samples
        A list (or array) of samples, with shape
        ``(n_samples, n_parameters)``, where ``n_samples`` is the number of
        samples in the list and ``n_parameters`` is the number of parameters.
        Must contain at least one sample.
    problem
        A :class:`pints.SingleOutputProblem` or
        :class:`pints.MultiOutputProblem` whose ``n_parameters`` is less than
        or equal to the ``n_parameters`` of the ``samples``. Any extra
        parameters present in the chain but not accepted by the
        ``SingleOutputProblem`` or ``MultiOutputProblem`` (for example
        parameters added by a noise model) will be ignored.
    ref_parameters
        A set of parameters for reference in the plot. For example, if true
        values of parameters are known, they can be passed in for plotting.
        Its length must equal either the number of parameters in ``samples``
        or the number of parameters of ``problem``.
    thinning
        An integer exceeding zero. If specified, only every n-th sample (with
        ``n = thinning``) in the samples will be used. If left at the default
        value ``None``, a value will be chosen so that 200 to 400 predictions
        are shown.

    Raises
    ------
    ValueError
        If ``samples`` is not two-dimensional, is empty, or has fewer
        parameters than ``problem``; if ``ref_parameters`` has the wrong
        length; or if ``thinning`` is smaller than one.
    """
    # Check samples size. Converting first means that plain (nested) lists
    # are accepted, as promised by the documentation.
    samples = np.asarray(samples)
    if samples.ndim != 2:
        raise ValueError(
            '`samples` must be of shape (n_sample, n_parameters).')
    n_sample, n_param = samples.shape
    if n_sample < 1:
        raise ValueError('`samples` must contain at least one sample.')

    # Get thinning rate
    if thinning is None:
        thinning = max(1, int(n_sample / 200))
    else:
        thinning = int(thinning)
        if thinning < 1:
            raise ValueError(
                'Thinning rate must be `None` or an integer greater than'
                ' zero.')

    # Imported here (after the cheap argument checks, but before any
    # simulation is run) so that matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt

    # Get problem n_parameters
    n_parameters = problem.n_parameters()
    if n_param < n_parameters:
        raise ValueError(
            '`samples` must contain at least as many parameters as the'
            ' problem.')

    # Check reference parameters
    ref_series = None
    if ref_parameters is not None:
        if len(ref_parameters) not in (n_param, n_parameters):
            raise ValueError(
                'Length of `ref_parameters` must be same as number of'
                ' parameters.')
        ref_series = problem.evaluate(ref_parameters[:n_parameters])

    # Get number of problem output
    n_outputs = problem.n_outputs()

    # Get times
    times = problem.times()

    # Evaluate the model for all parameter sets in the (thinned) samples,
    # ignoring any parameters the problem does not accept
    predicted_values = np.array([
        problem.evaluate(params)
        for params in samples[::thinning, :n_parameters]
    ])
    mean_values = np.mean(predicted_values, axis=0)

    # Guess appropriate alpha (0.05 worked for 1000 plots)
    # NOTE(review): because of the lower bound of 0.5 the result always lies
    # in [0.5, 1], so 1000 plots get alpha 0.5 rather than 0.05. Behaviour is
    # kept as-is; confirm whether ``min(..., 0.5)`` was intended.
    alpha = min(1, max(0.05 * (1000 / (n_sample / thinning)), 0.5))

    # Plot prediction
    fig, axes = plt.subplots(
        n_outputs, 1, figsize=(8, np.sqrt(n_outputs) * 3), sharex=True)

    if n_outputs == 1:
        # With a single row, ``axes`` is a single axes object
        axes.set_xlabel('Time')
        axes.set_ylabel('Value')
        _plot_series_on_axes(
            axes, times, problem.values(), predicted_values, mean_values,
            alpha, ref_series)
        axes.legend()
    else:
        # Remove horizontal space between axes and set common xlabel
        fig.subplots_adjust(hspace=0)
        axes[-1].set_xlabel('Time')

        # Go through each output
        data = problem.values()
        for i_output in range(n_outputs):
            axes[i_output].set_ylabel('Output %d' % (i_output + 1))
            _plot_series_on_axes(
                axes[i_output],
                times,
                data[:, i_output],
                [v[:, i_output] for v in predicted_values],
                mean_values[:, i_output],
                alpha,
                None if ref_series is None else ref_series[:, i_output],
            )
        axes[0].legend()

    plt.tight_layout()
    return fig, axes


def _plot_series_on_axes(
        ax, times, data, predictions, mean, alpha, reference=None):
    """
    Draws the data, all predicted series, their mean, and (optionally) a
    reference series for a single output on the axes ``ax``.
    """
    ax.plot(
        times, data, 'x', color='#7f7f7f', ms=6.5, alpha=0.5,
        label='Original data')

    # Only the first prediction is labelled (and drawn opaque), so that the
    # legend contains a single, clearly visible 'Inferred series' entry
    ax.plot(
        times, predictions[0], color='#1f77b4', label='Inferred series')
    for v in predictions[1:]:
        ax.plot(times, v, color='#1f77b4', alpha=alpha)

    ax.plot(times, mean, 'k:', lw=2, label='Mean of inferred series')

    # Add reference series if given
    if reference is not None:
        ax.plot(
            times, reference, color='#d62728', ls='--',
            label='Reference series')