# Source code for aurora.plot_tools

import matplotlib.pyplot as plt
import numpy as np, copy
import matplotlib.gridspec as mplgs
import matplotlib.widgets as mplw
import itertools
# Enable matplotlib interactive mode as soon as this module is imported
# (a global side effect on matplotlib state). Presumably this is so that
# slider_plot figures appear and respond without an explicit plt.show() — verify.
plt.ion()


def slider_plot(x, y, z, xlabel='', ylabel='', zlabel='', labels=None,
                plot_sum=False, x_line=None, y_line=None, **kwargs):
    """Make a plot to explore multidimensional data.

    A slider, which can also be stepped with the left/right arrow keys,
    selects an index ``i`` along the last axis of `z`. Every curve
    ``z[p, :, i]`` is then drawn against `x`.

    Args:
        x : array of float, (`M`,)
            The abscissa. (in aurora, often this may be rhop)
        y : array of float, (`N`,)
            The variable to slide over. (in aurora, often this may be time)
        z : array of float, (`P`, `M`, `N`)
            The variables to plot. A sequence of `P` arrays of shape
            (`M`, `N`) is also accepted and is copied into a new array.
        xlabel : str, optional
            The label for the abscissa.
        ylabel : str, optional
            The label for the slider.
        zlabel : str, optional
            The label for the ordinate.
        labels : list of str with length `P`
            The labels for each curve in `z`.
        plot_sum : bool, optional
            If True, will also plot the sum over all `P` cases. Default is False.
        x_line : float, optional
            x coordinate at which a vertical line will be drawn.
        y_line : float, optional
            y coordinate at which a horizontal line will be drawn.
        **kwargs
            Forwarded to every ``plot`` call (the curves and the legend proxies).
    """
    # Work on a private copy so the caller's data is never modified and later
    # changes to it do not leak into the interactive plot. Converting to an
    # ndarray (subok=True keeps subclasses such as masked arrays) also makes
    # the `zz[:, :, i]` indexing used for `plot_sum` valid when `z` is given
    # as a list of 2D arrays.
    zz = np.array(z, copy=True, subok=True)
    if labels is None:
        labels = ['' for _ in zz]

    fig = plt.figure()
    fig.set_size_inches(10, 7, forward=True)

    # separate plot into 3 subgrids: curves, legend, slider
    a_plot = plt.subplot2grid((10, 10), (0, 0), rowspan=8, colspan=8, fig=fig)
    a_legend = plt.subplot2grid((10, 10), (0, 8), rowspan=8, colspan=2, fig=fig)
    a_slider = plt.subplot2grid((10, 10), (9, 0), rowspan=1, colspan=8, fig=fig)

    a_plot.set_xlabel(xlabel)
    a_plot.set_ylabel(zlabel)
    if x_line is not None:
        a_plot.axvline(x_line, c='r', ls=':', lw=0.5)
    if y_line is not None:
        a_plot.axhline(y_line, c='r', ls=':', lw=0.5)

    ls_cycle = get_ls_cycle()
    lines = []
    # Plot all curves at slider index 0. The legend axes receives empty proxy
    # lines with the same style and label, so the legend can live in its own
    # axes.
    for v, label in zip(zz, labels):
        ls = next(ls_cycle)
        line, = a_plot.plot(x, v[:, 0], ls, **kwargs)
        a_legend.plot([], [], ls, label=label, **kwargs)
        lines.append(line)

    if plot_sum:
        # add sum of the first axis to the plot (and legend)
        ls = next(ls_cycle)
        l_sum, = a_plot.plot(x, zz[:, :, 0].sum(axis=0), ls, **kwargs)
        a_legend.plot([], [], ls, label='total', **kwargs)

    # Legend.set_draggable returns a drag helper rather than the legend, so
    # keep the legend itself and enable dragging in a separate statement.
    legend = a_legend.legend(loc='best', fontsize=12)
    legend.set_draggable(True)
    title = fig.suptitle('')
    a_legend.axis('off')
    a_slider.axis('off')

    def update(dum):
        """Slider callback: redraw all curves at the current slider index."""
        # The argument is ignored; the index is read back from the slider.
        # round() guards against float values such as 2.9999999.
        i = int(round(slider.val))
        for v, line in zip(zz, lines):
            line.set_ydata(v[:, i])
        if plot_sum:
            l_sum.set_ydata(zz[:, :, i].sum(axis=0))
        a_plot.relim()
        a_plot.autoscale()
        title.set_text('%s = %.5f' % (ylabel, y[i]) if ylabel else '%.5f' % (y[i],))
        fig.canvas.draw_idle()

    def arrow_respond(slider, event):
        """Step the slider by one index on right/left arrow key presses."""
        if event.key == 'right':
            slider.set_val(min(slider.val + 1, slider.valmax))
        elif event.key == 'left':
            slider.set_val(max(slider.val - 1, slider.valmin))

    # valstep=1 snaps the slider to integer indices into `y`, so the
    # displayed value always matches the plotted slice.
    slider = mplw.Slider(
        a_slider, ylabel, 0, len(y) - 1, valinit=0, valfmt='%d', valstep=1
    )
    slider.on_changed(update)
    update(0)
    # The lambda is held strongly by the canvas callback registry and closes
    # over `slider`, which keeps the widget alive after this function returns.
    fig.canvas.mpl_connect(
        'key_press_event', lambda evt: arrow_respond(slider, evt)
    )
def get_ls_cycle():
    """Return an endless iterator over matplotlib color+linestyle format strings.

    Colors vary fastest. All seven colors are first paired with ``'-'``, then
    with ``'--'``, ``'-.'`` and ``':'``. After these 28 distinct entries the
    sequence repeats.
    """
    pairs = itertools.product(['-', '--', '-.', ':'],
                              ['b', 'g', 'r', 'c', 'm', 'y', 'k'])
    return itertools.cycle([color + style for style, color in pairs])
def get_color_cycle():
    """Return an endless iterator over the single-letter matplotlib colors b, g, r, c, m, y, k."""
    palette = list('bgrcmyk')
    return itertools.cycle(palette)
def get_line_cycle():
    """Return an endless iterator over the matplotlib line styles '-', '--', '-.', ':'."""
    dashes = ('-', '--', '-.', ':')
    return itertools.cycle(dashes)