# Source code for pints.plot._function
#
# Evaluate a function around a point
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import numpy as np
import pints
# [docs]
def function(f, x, lower=None, upper=None, evaluations=20):
    """
    Creates 1d plots of a :class:`LogPDF` or a :class:`ErrorMeasure` around a
    point `x` (i.e. a 1-dimensional plot in each direction).

    For each parameter, all other parameters are held fixed at their value in
    ``x`` while that parameter is varied linearly between its lower and upper
    bound.

    Returns a ``matplotlib`` figure object and axes handle.

    Parameters
    ----------
    f
        A :class:`pints.LogPDF` or :class:`pints.ErrorMeasure` to plot.
    x
        A point in the function's input space.
    lower
        Optional lower bounds for each parameter, used to specify the lower
        bounds of the plot. If omitted, each bound is guessed as 5% of
        ``|x|`` below ``x`` (or ``-1`` where that guess equals zero).
    upper
        Optional upper bounds for each parameter, used to specify the upper
        bounds of the plot. If omitted, each bound is guessed as 5% of
        ``|x|`` above ``x`` (or ``1`` where that guess equals zero).
    evaluations
        The number of evaluations to use in each plot.

    Returns
    -------
    fig
        The created ``matplotlib`` figure.
    axes
        An array with one axes object per parameter (also when the function
        has only a single parameter).

    Raises
    ------
    ValueError
        If ``f`` is not a :class:`pints.LogPDF` or
        :class:`pints.ErrorMeasure`, if ``x``, ``lower`` or ``upper`` do not
        have one entry per parameter, or if ``evaluations`` is less than 1.
    """
    # Imported locally so that matplotlib is only required when plotting
    import matplotlib.pyplot as plt

    # Check function and get n_parameters
    if not isinstance(f, (pints.LogPDF, pints.ErrorMeasure)):
        raise ValueError(
            'Given function must be pints.LogPDF or pints.ErrorMeasure.')
    n_param = f.n_parameters()

    # Check point
    x = pints.vector(x)
    if len(x) != n_param:
        raise ValueError(
            'Given point `x` must have same number of parameters as function.')

    # Check boundaries
    if lower is None:
        # Guess boundaries based on point x. Scaling by 0.95 only moves
        # *down* for positive values; for negative values it is 1.05 that
        # moves down, so take the element-wise minimum of both candidates.
        lower = np.minimum(x * 0.95, x * 1.05)
        lower[lower == 0] = -1
    else:
        lower = pints.vector(lower)
        if len(lower) != n_param:
            raise ValueError('Lower bounds must have same number of'
                             + ' parameters as function.')
    if upper is None:
        # Guess boundaries based on point x (see the note on signs above)
        upper = np.maximum(x * 0.95, x * 1.05)
        upper[upper == 0] = 1
    else:
        upper = pints.vector(upper)
        if len(upper) != n_param:
            raise ValueError('Upper bounds must have same number of'
                             + ' parameters as function.')

    # Check number of evaluations
    evaluations = int(evaluations)
    if evaluations < 1:
        raise ValueError('Number of evaluations must be greater than zero.')

    # Create points to plot: one consecutive group of `evaluations` rows per
    # parameter. Every row starts as a copy of x; within group j only column
    # j is swept from lower[j] to upper[j].
    xs = np.tile(x, (n_param * evaluations, 1))
    for j in range(n_param):
        i1 = j * evaluations
        i2 = i1 + evaluations
        xs[i1:i2, j] = np.linspace(lower[j], upper[j], evaluations)

    # Evaluate points
    fs = pints.evaluate(f, xs, parallel=False)

    # Create figure
    fig, axes = plt.subplots(n_param, 1, figsize=(6, 2 * n_param))
    if n_param == 1:
        # With a single subplot a bare axes object is returned; wrap it so
        # that it can be indexed (and returned) like the multi-plot case.
        axes = np.asarray([axes], dtype=object)
    for j, p in enumerate(x):
        i1 = j * evaluations
        i2 = i1 + evaluations
        axes[j].plot(xs[i1:i2, j], fs[i1:i2], c='green', label='Function')
        axes[j].axvline(p, c='blue', label='Value')
        axes[j].set_xlabel('Parameter ' + str(1 + j))
        axes[j].legend()

    # Lay out the figure created above, rather than relying on it still
    # being pyplot's "current" figure.
    fig.tight_layout()
    return fig, axes