Source code for aurora.animate

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation
plt.ion()
import pickle as pkl
from . import plot_tools


def animate_aurora(x, y, z, xlabel='', ylabel='', zlabel='', labels=None,
                   plot_sum=False, uniform_y_spacing=True, save_filename=None):
    '''Produce animation of time- and radially-dependent results from aurora.

    Args:
        x : array of float, (`M`,)
            The abscissa. (in aurora, often this may be rhop)
        y : array of float, (`N`,)
            The variable to slide over. (in aurora, often this may be time)
        z : array of float, (`P`, `M`, `N`)
            The variables to plot.
        xlabel : str, optional
            The label for the abscissa.
        ylabel : str, optional
            The label for the animated coordinate. This is expected in a format
            such that ylabel.format(y_val) will display a good moving label,
            e.g. ylabel='t={:.4f} s'.
        zlabel : str, optional
            The label for the ordinate.
        labels : list of str with length `P`
            The labels for each curve in `z`. The list passed by the caller is
            not modified. A label equal to 'total' is drawn as a thick black
            line.
        plot_sum : bool, optional
            If True, will also plot the sum over all `P` cases, labelled
            'total'. Default is False.
        uniform_y_spacing : bool, optional
            If True, interpolate values in z onto a uniformly-spaced y grid
            spanning min(y) to max(y) with len(y) points. Default is True.
        save_filename : str
            If not None, the animation is saved in mp4 format to
            `save_filename + '.mp4'` (the extension is always appended).

    Returns:
        anim : matplotlib.animation.FuncAnimation
            The animation object. Keep a reference to it for as long as the
            animation should keep running.
    '''
    # Work on a private copy: appending 'total' below must not leak into the
    # list owned by the caller.
    labels = ['' for _ in z] if labels is None else list(labels)

    if plot_sum:
        labels.append('total')
        z_sum = np.sum(z, axis=0)
        # reshape the (M, N) sum to (1, M, N) so it stacks as one more curve
        z = np.vstack((z, np.atleast_3d(z_sum).transpose(2, 0, 1)))

    if uniform_y_spacing:
        from scipy.interpolate import RegularGridInterpolator as rgi

        curve_idx = np.arange(z.shape[0])
        interp_fun = rgi((curve_idx, x, y), z)
        y_eq = np.linspace(min(y), max(y), len(y))
        # open mesh: only the y axis is resampled, curve index and x are kept
        new_grid = np.ix_(curve_idx, x, y_eq)
        z = interp_fun(new_grid)
        y_frames = y_eq
    else:
        y_frames = y

    # set up the figure and side space for legend
    fig = plt.figure(figsize=(10, 6))
    a_plot = plt.subplot2grid((1, 10), (0, 0), rowspan=1, colspan=8, fig=fig)
    a_legend = plt.subplot2grid((1, 10), (0, 8), rowspan=1, colspan=2, fig=fig)

    a_plot.set_xlabel(xlabel)
    a_plot.set_ylabel(zlabel)
    a_plot.set_xlim([np.min(x), np.max(x)])
    a_plot.set_ylim([0.0, 1.1 * np.max(z)])

    # get nice sequence of line styles/colors
    ls_cycle = plot_tools.get_ls_cycle()

    lines = []
    for label in labels:
        is_total = label == 'total'
        # Draw a single style per curve and share it between the legend proxy
        # and the animated line; advancing the cycle separately for each would
        # make the legend disagree with the plotted curves.
        style = 'k-' if is_total else next(ls_cycle)
        width = 2.5 if is_total else 1.0
        a_legend.plot([], [], style, lw=width, label=label)
        lines.append(a_plot.plot([], [], style, lw=width)[0])

    # time label (NB: update won't work if this is placed outside axes)
    y_text = a_plot.text(0.75, 0.95, ' ', fontsize=14, transform=a_plot.transAxes)

    def init():
        # initialization function: blank out all animated artists
        for line in lines:
            line.set_data([], [])
        y_text.set_text('')
        return tuple(lines) + (y_text,)

    def animate(i):
        # animation function, called sequentially with the frame index
        y_text.set_text(ylabel.format(y_frames[i]))
        for lnum, line in enumerate(lines):
            line.set_data(x, z[lnum, :, i])
        return tuple(lines) + (y_text,)

    a_legend.legend(loc='center').set_draggable(True)
    a_legend.axis('off')

    # run animation now:
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=len(y), interval=20, blit=True)

    if save_filename is not None:
        anim.save(save_filename + '.mp4', fps=30, extra_args=['-vcodec', 'libx264'])

    # Hand the animation back so the caller can keep it alive; an unreferenced
    # animation object may be garbage-collected, which stops the animation.
    return anim