# Source code for pints.plot._function_between_points

#
# Evaluate function between two points
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import numpy as np

import pints


def function_between_points(
        f, point_1, point_2, padding=0.25, evaluations=20):
    """
    Creates and returns a plot of a function between two points in parameter
    space.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    f
        A :class:`pints.LogPDF` or :class:`pints.ErrorMeasure` to plot.
    point_1
        The first point in parameter space. The method will find a line from
        ``point_1`` to ``point_2`` and plot ``f`` at several points along it.
    point_2
        The second point.
    padding
        Specifies the amount of padding around the line segment
        ``[point_1, point_2]`` that will be shown in the plot, expressed as a
        fraction of the segment length (``0.25`` extends the line by a
        quarter of its length beyond each end). Must be finite and
        non-negative.
    evaluations
        The number of evaluations along the line in parameter space. Must be
        3 or greater.

    Returns
    -------
    fig, axes
        The ``matplotlib`` figure and its single axes. The x-axis is the line
        coordinate (``0`` at ``point_1``, ``1`` at ``point_2``) and the
        y-axis is the value of ``f``.

    Raises
    ------
    ValueError
        If ``f`` is not a :class:`pints.LogPDF` or
        :class:`pints.ErrorMeasure`, if either point does not have
        ``f.n_parameters()`` entries, if ``padding`` is negative or not
        finite, or if ``evaluations`` is less than 3.
    """
    import matplotlib.pyplot as plt

    # Check function and get n_parameters
    if not isinstance(f, (pints.LogPDF, pints.ErrorMeasure)):
        raise ValueError(
            'Given function must be pints.LogPDF or pints.ErrorMeasure.')
    n_param = f.n_parameters()

    # Check points
    point_1 = pints.vector(point_1)
    point_2 = pints.vector(point_2)
    if not (len(point_1) == len(point_2) == n_param):
        raise ValueError('Both points must have the same number of parameters'
                         ' as the given function.')

    # Check padding. The negative test runs first so -inf keeps the original
    # message. NaN compares False against everything, so it would slip past
    # the ``< 0`` test and silently produce an all-NaN line coordinate; +inf
    # would break ``np.linspace``. Reject both explicitly.
    padding = float(padding)
    if padding < 0:
        raise ValueError('Padding cannot be negative.')
    if not np.isfinite(padding):
        raise ValueError('Padding must be finite.')

    # Check evaluation count
    evaluations = int(evaluations)
    if evaluations < 3:
        raise ValueError('The number of evaluations must be 3 or greater.')

    # Figure setting
    fig, axes = plt.subplots(1, 1, figsize=(6, 4))
    axes.set_xlabel('Point 1 to point 2')
    axes.set_ylabel('Function')

    # Line coordinate: s == 0 is point_1, s == 1 is point_2, and the padding
    # extends the range past both ends.
    s = np.linspace(-padding, 1 + padding, evaluations)

    # Direction of the segment from point_1 to point_2
    r = point_2 - point_1

    # Evaluate f at each point along the (padded) line
    x = [point_1 + sj * r for sj in s]
    y = pints.evaluate(f, x, parallel=False)

    # Plot the profile and mark the two given points with vertical lines
    axes.plot(s, y, color='green')
    axes.axvline(0, color='#1f77b4', label='Point 1')
    axes.axvline(1, color='#7f7f7f', label='Point 2')
    axes.legend()

    return fig, axes