import matplotlib.pyplot as plt
import numpy as np, copy
import matplotlib.gridspec as mplgs
import matplotlib.widgets as mplw
import itertools
# Switch matplotlib to interactive mode as soon as this module is loaded
# (module-level side effect).
plt.ion()
def slider_plot(x, y, z, xlabel='', ylabel='', zlabel='', labels=None, plot_sum=False,
                x_line=None, y_line=None, **kwargs):
    """Make a plot to explore multidimensional data.

    One curve per entry of `z` is drawn against `x`. A slider below the plot
    (also driven by the left/right arrow keys) selects the index along `y`
    whose slice is displayed; the figure title shows the selected `y` value.

    Args:
        x : array of float, (`M`,)
            The abscissa. (in aurora, often this may be rhop)
        y : array of float, (`N`,)
            The variable to slide over. (in aurora, often this may be time)
        z : array of float, (`P`, `M`, `N`)
            The variables to plot.
        xlabel : str, optional
            The label for the abscissa.
        ylabel : str, optional
            The label for the slider.
        zlabel : str, optional
            The label for the ordinate.
        labels : list of str with length `P`
            The labels for each curve in `z`.
        plot_sum : bool, optional
            If True, will also plot the sum over all `P` cases. Default is False.
        x_line : float, optional
            x coordinate at which a vertical line will be drawn.
        y_line : float, optional
            y coordinate at which a horizontal line will be drawn.
        **kwargs
            Extra keyword arguments forwarded to every `plot` call.
    """
    if labels is None:
        labels = ['' for _ in z]

    # Work on a private copy so the caller's `z` is never touched.
    zz = copy.deepcopy(z)
    # Coerce every curve to an array so `v[:, i]` also works for nested lists;
    # asanyarray keeps array subclasses (e.g. masked arrays) intact.
    curve_data = [np.asanyarray(v) for v in zz]
    # Sum over the `P` curves once, instead of on every slider move.
    z_total = np.asanyarray(zz).sum(axis=0) if plot_sum else None

    fig = plt.figure()
    fig.set_size_inches(10, 7, forward=True)

    # separate plot into 3 subgrids: main axes, legend column, slider row
    a_plot = plt.subplot2grid((10, 10), (0, 0), rowspan=8, colspan=8, fig=fig)
    a_legend = plt.subplot2grid((10, 10), (0, 8), rowspan=8, colspan=2, fig=fig)
    a_slider = plt.subplot2grid((10, 10), (9, 0), rowspan=1, colspan=8, fig=fig)

    a_plot.set_xlabel(xlabel)
    a_plot.set_ylabel(zlabel)

    if x_line is not None:
        a_plot.axvline(x_line, c='r', ls=':', lw=0.5)
    if y_line is not None:
        a_plot.axhline(y_line, c='r', ls=':', lw=0.5)

    ls_cycle = get_ls_cycle()
    curves = []

    # plot all lines; an empty twin on the legend axes carries the label
    for v, label in zip(curve_data, labels):
        ls = next(ls_cycle)
        line, = a_plot.plot(x, v[:, 0], ls, **kwargs)
        a_legend.plot([], [], ls, label=label, **kwargs)
        curves.append(line)

    if plot_sum:
        # add sum of the first axis to the plot (and legend)
        ls = next(ls_cycle)
        l_sum, = a_plot.plot(x, z_total[:, 0], ls, **kwargs)
        a_legend.plot([], [], ls, label='total', **kwargs)

    a_legend.legend(loc='best', fontsize=12).set_draggable(True)
    title = fig.suptitle('')

    # Only the legend and the slider widget should be visible on these axes.
    a_legend.axis('off')
    a_slider.axis('off')

    # valstep=1 makes the slider snap to whole indices, matching the '%d'
    # read-out and the integer indexing done in update().
    slider = mplw.Slider(
        a_slider,
        ylabel,
        0,
        len(y) - 1,
        valinit=0,
        valfmt='%d',
        valstep=1,
    )

    def update(_val):
        """Redraw every curve for the index currently selected on the slider."""
        i = int(round(slider.val))

        for v, line in zip(curve_data, curves):
            line.set_ydata(v[:, i])

        if plot_sum:
            l_sum.set_ydata(z_total[:, i])

        # Rescale the axes to the newly displayed data.
        a_plot.relim()
        a_plot.autoscale()

        if ylabel:
            title.set_text('%s = %.5f' % (ylabel, y[i]))
        else:
            title.set_text('%.5f' % (y[i],))

        fig.canvas.draw()

    def arrow_respond(slider, event):
        """Step the slider by one index on left/right key presses, clamped to its range."""
        if event.key == 'right':
            slider.set_val(min(slider.val + 1, slider.valmax))
        elif event.key == 'left':
            slider.set_val(max(slider.val - 1, slider.valmin))

    slider.on_changed(update)
    # Draw the initial state (index 0) and set the title.
    update(0)

    fig.canvas.mpl_connect(
        'key_press_event',
        lambda evt: arrow_respond(slider, evt)
    )
def get_ls_cycle():
    """Return an endless iterator over matplotlib format strings.

    Each item is a single-letter color code followed by a line style
    (e.g. ``'b-'``). All seven colors are paired with one line style before
    moving on to the next style, so the sequence repeats after 28 items.
    """
    color_vals = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    style_vals = ['-', '--', '-.', ':']
    # product() varies its last argument fastest: colors cycle within a style.
    ls_vals = [c + s for s, c in itertools.product(style_vals, color_vals)]
    return itertools.cycle(ls_vals)
def get_color_cycle():
    """Return an endless iterator over the single-letter color codes
    ``'b', 'g', 'r', 'c', 'm', 'y', 'k'``, in that order.
    """
    color_vals = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    return itertools.cycle(color_vals)
def get_line_cycle():
    """Return an endless iterator over the line-style strings
    ``'-', '--', '-.', ':'``, in that order.
    """
    style_vals = ['-', '--', '-.', ':']
    return itertools.cycle(style_vals)