Source code for aurora.plot_tools

# MIT License
#
# Copyright (c) 2021 Francesco Sciortino
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np, copy
import matplotlib.gridspec as mplgs
import matplotlib.widgets as mplw
from matplotlib.cm import ScalarMappable
import itertools

plt.ion()


[docs]def slider_plot( x, y, z, xlabel="", ylabel="", zlabel="", labels=None, plot_sum=False, x_line=None, y_line=None, **kwargs ): """Make a plot to explore multidimensional data. Parameters ---------- x : array of float, (`M`,) The abscissa. (in aurora, often this may be rhop) y : array of float, (`N`,) The variable to slide over. (in aurora, often this may be time) z : array of float, (`P`, `M`, `N`) The variables to plot. xlabel : str, optional The label for the abscissa. ylabel : str, optional The label for the slider. zlabel : str, optional The label for the ordinate. labels : list of str with length `P` The labels for each curve in `z`. plot_sum : bool, optional If True, will also plot the sum over all `P` cases. Default is False. x_line : float, optional x coordinate at which a vertical line will be drawn. y_line : float, optional y coordinate at which a horizontal line will be drawn. """ if labels is None: labels = ["" for v in z] # make sure not to modify the z array in place zz = copy.deepcopy(z) fig = plt.figure() fig.set_size_inches(10, 7, forward=True) # separate plot into 3 subgrids a_plot = plt.subplot2grid((10, 10), (0, 0), rowspan=8, colspan=8, fig=fig) a_legend = plt.subplot2grid((10, 10), (0, 8), rowspan=8, colspan=2, fig=fig) a_slider = plt.subplot2grid((10, 10), (9, 0), rowspan=1, colspan=8, fig=fig) a_plot.set_xlabel(xlabel) a_plot.set_ylabel(zlabel) if x_line is not None: a_plot.axvline(x_line, c="r", ls=":", lw=0.5) if y_line is not None: a_plot.axhline(y_line, c="r", ls=":", lw=0.5) ls_cycle = get_ls_cycle() l = [] # plot all lines for v, l_ in zip(zz, labels): ls = next(ls_cycle) (tmp,) = a_plot.plot(x, v[:, 0], ls, **kwargs) _ = a_legend.plot([], [], ls, label=l_, **kwargs) l.append(tmp) if plot_sum: # add sum of the first axis to the plot (and legend) (l_sum,) = a_plot.plot( x, zz[:, :, 0].sum(axis=0), c="k", lw=mpl.rcParams["lines.linewidth"] * 2, **kwargs ) _ = a_legend.plot( [], [], c="k", lw=mpl.rcParams["lines.linewidth"] * 2, label="total", **kwargs ) leg = a_legend.legend(loc="best", fontsize=12).set_draggable(True) title = fig.suptitle("") a_legend.axis("off") a_slider.axis("off") def update(dum): i = int(slider.val) for v, l_ in zip(zz, l): l_.set_ydata(v[:, i]) if plot_sum: l_sum.set_ydata(zz[:, :, i].sum(axis=0)) a_plot.relim() a_plot.autoscale() title.set_text("%s = %.5f" % (ylabel, y[i]) if ylabel else "%.5f" % (y[i],)) fig.canvas.draw() def arrow_respond(slider, event): if event.key == "right": slider.set_val(min(slider.val + 1, slider.valmax)) elif event.key == "left": slider.set_val(max(slider.val - 1, slider.valmin)) slider = mplw.Slider(a_slider, ylabel, 0, len(y) - 1, valinit=0, valfmt="%d") slider.on_changed(update) update(0) fig.canvas.mpl_connect("key_press_event", lambda evt: arrow_respond(slider, evt))
[docs]def get_ls_cycle(): color_vals = ["b", "g", "r", "c", "m", "y", "k"] style_vals = ["-", "--", "-.", ":"] ls_vals = [] for s in style_vals: for c in color_vals: ls_vals.append(c + s) return itertools.cycle(ls_vals)
[docs]def get_color_cycle(num=None, map="plasma"): """Get an iterable to select different colors in a loop. Efficiently splits a chosen color map into as many (`num`) parts as needed. """ cols = ["b", "g", "r", "c", "m", "y", "k"] if num is None or num <= len(cols): return itertools.cycle(cols[:num]) cm = plt.get_cmap(map) cols = np.empty(num) for j in np.arange(num): cols[j] = cm(1.0 * j / num) return itertools.cycle(cols)
[docs]def get_line_cycle(): style_vals = ["-", "--", "-.", ":"] return itertools.cycle(style_vals)
[docs]class DraggableColorbar: """Create a draggable colorbar for matplotlib plots to enable quick changes in color scale. Example::: fig,ax = plt.subplots() cntr = ax.contourf(R, Z, vals) cbar = plt.colorbar(cntr, format='%.3g', ax=ax) cbar = DraggableColorbar(cbar,cntr) cbar.connect() """ def __init__(self, cbar, mapimage): self.cbar = cbar self.mapimage = mapimage self.press = None self.cycle = sorted( [i for i in dir(plt.cm) if hasattr(getattr(plt.cm, i), "N")] ) self.index = self.cycle.index(ScalarMappable.get_cmap(cbar).name)
[docs] def connect(self): """Matplotlib connection for button and key pressing, release, and motion.""" self.cidpress = self.cbar.patch.figure.canvas.mpl_connect( "button_press_event", self.on_press ) self.cidrelease = self.cbar.patch.figure.canvas.mpl_connect( "button_release_event", self.on_release ) self.cidmotion = self.cbar.patch.figure.canvas.mpl_connect( "motion_notify_event", self.on_motion ) self.keypress = self.cbar.patch.figure.canvas.mpl_connect( "key_press_event", self.key_press )
[docs] def on_press(self, event): """Button pressing; check if mouse is over colorbar.""" if event.inaxes != self.cbar.ax: return self.press = event.x, event.y
[docs] def key_press(self, event): """Key pressing event""" if event.key == "down": self.index += 1 elif event.key == "up": self.index -= 1 if self.index < 0: self.index = len(self.cycle) elif self.index >= len(self.cycle): self.index = 0 cmap = self.cycle[self.index] self.cbar.set_cmap(cmap) self.cbar.draw_all() self.mapimage.set_cmap(cmap) self.mapimage.get_axes().set_title(cmap) self.cbar.patch.figure.canvas.draw()
[docs] def on_motion(self, event): """Move if the mouse is over the colorbar.""" if self.press is None: return if event.inaxes != self.cbar.ax: return xprev, yprev = self.press dx = event.x - xprev dy = event.y - yprev self.press = event.x, event.y scale = self.cbar.norm.vmax - self.cbar.norm.vmin perc = 0.03 if event.button == 1: self.cbar.norm.vmin -= (perc * scale) * np.sign(dy) self.cbar.norm.vmax -= (perc * scale) * np.sign(dy) elif event.button == 3: self.cbar.norm.vmin -= (perc * scale) * np.sign(dy) self.cbar.norm.vmax += (perc * scale) * np.sign(dy) self.cbar.draw_all() self.mapimage.set_norm(self.cbar.norm) self.cbar.patch.figure.canvas.draw()
[docs] def on_release(self, event): """Upon release, reset press data""" self.press = None self.mapimage.set_norm(self.cbar.norm) self.cbar.patch.figure.canvas.draw()
[docs] def disconnect(self): self.cbar.patch.figure.canvas.mpl_disconnect(self.cidpress) self.cbar.patch.figure.canvas.mpl_disconnect(self.cidrelease) self.cbar.patch.figure.canvas.mpl_disconnect(self.cidmotion)