Source code for cibrrig.videos

import logging

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.animation import FuncAnimation
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Line3DCollection

logging.basicConfig()
# Use this module's own logger instead of the root logger so that importing
# this module only changes the verbosity of its own messages, not of every
# library in the process.
_log = logging.getLogger(__name__)
_log.setLevel(logging.INFO)

# Default keyword arguments for the plotting/animation functions below.
# These dicts are used directly as default argument values, so any in-place
# modification would leak into every later call: copy before modifying.
PROJECTION_KWARGS = dict(lw=0.5, alpha=0.2, color="C0", cmap="RdBu_r", s=1)
TRAIL_KWARGS = dict(lw=3, alpha=1, color="C1")
HISTORY_KWARGS = dict(color="C1", lw=0.75, alpha=0.75)
AUX_KWARGS = dict(lw=0.5, color="C1")
# TODO: Organize default arguments? Refactor to avoid repetition?
def make_aux_raster_projection_with_stims(
    pop,
    intervals,
    aux,
    aux_t,
    fn_out,
    stim_color,
    aux_label="",
    frame_step=0.01,
    duration=2,
    cmap="bone",
    vmax=0.5,
    baseline=10,
    fps=30,
    dpi=300,
    dims=[0, 1],
    lead_in=1,
    winsize=1,
    trail_length=0.05,
    figsize=(3, 8),
    azim_speed=0.1,
    elev_speed=0.1,
    rotation_delay=1,
    rotate=True,
    style="dark_background",
    projection_kwargs=PROJECTION_KWARGS,
    trail_kwargs=TRAIL_KWARGS,
    history_kwargs=HISTORY_KWARGS,
    aux_kwargs=AUX_KWARGS,
):
    """
    Make a video of the projection with a raster and an auxiliary trace.

    The figure stacks three panels: the auxiliary trace (with one bar per
    stimulus interval) on top, the smoothed raster in the middle and the
    projection (2D, or 3D when ``len(dims) == 3``) at the bottom. Each frame
    advances time by ``frame_step``: the raster and auxiliary windows scroll so
    that the current time sits at x=0, and the projection trajectory is
    revealed up to the current time. The animation is saved to ``fn_out``.

    Args:
        pop (Population): Population object containing the projection data.
            Reads ``pop.projection``, ``pop.tbins``, ``pop.cbins`` and
            ``pop.raster_smoothed`` and calls ``pop.plot_projection_line``.
        intervals (np.ndarray): Array of intervals (n x 2) for the stimulus periods.
        aux (np.ndarray): Auxiliary data to be plotted.
        aux_t (np.ndarray): Time vector for the auxiliary data. Must be sorted
            (it is searched with ``np.searchsorted``).
        fn_out (str or Path): Output file path for the video.
        stim_color (str): Color for the stimulus periods.
        aux_label (str, optional): Label for the auxiliary trace. Defaults to "".
        frame_step (float, optional): Time step between frames in seconds. Defaults to 0.01.
        duration (float, optional): Duration of the video in seconds. Defaults to 2.
        cmap (str, optional): Colormap for the raster. Defaults to "bone".
        vmax (float, optional): Maximum value for the raster colormap normalization.
            Defaults to 0.5.
        baseline (int, optional): How far before the first frame the grey
            baseline projection starts. 0 draws a fully transparent baseline
            spanning 30 instead. Defaults to 10.
        fps (int, optional): Frames per second for the video. Defaults to 30.
        dpi (int, optional): Dots per inch for the video. Defaults to 300.
        dims (list, optional): Dimensions of the projection to plot (2 or 3).
            Defaults to [0, 1].
        lead_in (float, optional): Time before the start of the first interval
            at which the video starts. Must be shorter than ``duration``.
            Defaults to 1.
        winsize (float, optional): Half-width of the scrolling raster/aux window.
            Defaults to 1.
        trail_length (float, optional): Length of the highlighted trail in the
            projection. Defaults to 0.05.
        figsize (tuple, optional): Figure size for the video. Defaults to (3, 8).
        azim_speed (float, optional): Azimuth change per frame. Defaults to 0.1.
        elev_speed (float, optional): Elevation change per frame. Defaults to 0.1.
        rotation_delay (float, optional): Elapsed time after which rotation
            starts. Defaults to 1.
        rotate (bool, optional): Whether to rotate the 3D plot. Defaults to True.
        style (str, optional): Matplotlib style to use. Defaults to "dark_background".
        projection_kwargs (dict, optional): Keyword arguments for the baseline
            projection. Not modified. Defaults to PROJECTION_KWARGS.
        trail_kwargs (dict, optional): Keyword arguments for the trail. Passed to
            ``ax.plot``; must contain "color". Defaults to TRAIL_KWARGS.
        history_kwargs (dict, optional): Keyword arguments for the history. Passed
            to LineCollection; must contain "alpha". Defaults to HISTORY_KWARGS.
        aux_kwargs (dict, optional): Keyword arguments for the auxiliary trace.
            Passed to ``ax.plot``; must contain "color". Defaults to AUX_KWARGS.

    Returns:
        None
    """
    plt.style.use(style)
    assert lead_in < duration, f"{lead_in=} must be shorter than {duration=}"
    assert (
        isinstance(intervals, np.ndarray) and intervals.shape[1] == 2
    ), "Intervals must be an n x 2 array"
    assert pop.projection is not None, "Projection is not yet computed"

    # Copy so that neither the caller's dict nor the module-level default
    # (shared by every call) is modified below.
    projection_kwargs = dict(projection_kwargs)

    t0 = intervals[0, 0] - lead_in
    tf = t0 + duration
    is_3d = len(dims) == 3

    if baseline == 0:
        # Keep drawing a baseline, but make it invisible.
        baseline = 30
        projection_kwargs["alpha"] = 0

    # Set up figure and axes layout =====
    f = plt.figure(figsize=figsize, dpi=dpi)
    gs = f.add_gridspec(nrows=15, ncols=1)
    if is_3d:
        ax = f.add_subplot(gs[8:, :], projection="3d")
    else:
        ax = f.add_subplot(gs[8:, :])
    ax_raster = f.add_subplot(gs[1:8, :])
    ax_aux = f.add_subplot(gs[0, :], sharex=ax_raster)

    # Plot baseline
    if baseline > 0:
        ax = pop.plot_projection_line(
            dims=dims, t0=t0 - baseline, tf=t0, ax=ax, **projection_kwargs
        )

    # Plot trail (thick line, current timepoints) ======
    s0, sf = np.searchsorted(pop.tbins, [t0 - trail_length, t0])
    if is_3d:
        (trail1,) = ax.plot(
            pop.projection[s0:sf, dims[0]],
            pop.projection[s0:sf, dims[1]],
            pop.projection[s0:sf, dims[2]],
            **trail_kwargs,
        )
    else:
        (trail1,) = ax.plot(
            pop.projection[s0:sf, dims[0]],
            pop.projection[s0:sf, dims[1]],
            **trail_kwargs,
        )

    # Plot history (thin line, all timepoints; revealed frame by frame) ======
    s0, sf = np.searchsorted(pop.tbins, [t0, tf])
    segments = np.stack(
        [pop.projection[s0 : sf - 1, dims], pop.projection[s0 + 1 : sf, dims]], axis=1
    )
    if is_3d:
        history = Line3DCollection(segments, **history_kwargs)
    else:
        history = LineCollection(segments, **history_kwargs)
    ax.add_collection(history)
    history.set_color("none")

    # Segment start times and colours do not depend on the frame, so compute
    # them once instead of in every update() call.
    history_t = pop.tbins[s0 : sf - 1]
    history_colors = np.array(
        [trail_kwargs["color"]] * segments.shape[0], dtype="object"
    )
    for t1, t2 in intervals:
        history_colors[(history_t >= t1) & (history_t <= t2)] = stim_color

    # Plot Raster ============================
    cell_ids = np.arange(pop.cbins.shape[0])
    s0, sf = np.searchsorted(pop.tbins, [t0 - winsize, t0 + winsize])
    quad = ax_raster.pcolormesh(
        pop.tbins[s0:sf] - t0,
        cell_ids,
        pop.raster_smoothed[:, s0:sf],
        cmap=cmap,
        vmax=vmax,
    )
    ax_raster.axvline(0, color=plt.rcParams["text.color"], ls=":", lw=2)
    ax_raster.axis("off")

    # Raster annotations and scale bars
    ax_raster.set_yticks([])
    ax_raster.set_xlabel("Time (s)")
    ax_raster.set_xticks([-winsize, 0, winsize])
    ax_raster.set_xlim([-winsize * 1.1, winsize])
    ax_raster.vlines(-winsize * 1.05, 0, 25, lw=3)
    ax_raster.text(
        winsize * -1.08, 0, "25 neurons", rotation=90, ha="right", va="bottom"
    )
    ax_raster.hlines(-cell_ids.shape[0] * 0.05, -winsize, -winsize + winsize / 5, lw=3)
    ax_raster.text(
        -winsize,
        -cell_ids.shape[0] * 0.075,
        f"{winsize/5*1000:0.0f}ms",
        ha="left",
        va="top",
    )

    # Plot aux =================================
    s0, sf = np.searchsorted(aux_t, [t0 - winsize, t0 + winsize])
    (dd,) = ax_aux.plot(aux_t[s0:sf] - t0, aux[s0:sf], **aux_kwargs)
    ax_aux.axis("off")
    ax_aux.axvline(0, ls=":", lw=2)
    s0, sf = np.searchsorted(aux_t, [t0, tf])
    yy = np.min(aux[s0:sf]), np.max(aux[s0:sf])
    ax_aux.set_ylim(yy[0], yy[1] * 1.1)
    ax_aux.text(
        ax_aux.get_xlim()[0],
        np.mean(yy),
        aux_label,
        rotation=90,
        ha="right",
        va="center",
        color=aux_kwargs["color"],
    )

    # Stimulus bars above the aux trace
    ymax = ax_aux.get_ylim()[1]
    stim_y = ymax * 0.9
    aux_stims = ax_aux.hlines(
        np.ones(intervals.shape[0]) * stim_y,
        intervals[:, 0] - t0,
        intervals[:, 1] - t0,
        color=stim_color,
        lw=4,
    )

    def _stim_segments(elapsed):
        """Stimulus bar segments at ``elapsed`` seconds after t0, in the same
        current-time-relative x coordinates as the aux trace."""
        return [
            [(start - t0 - elapsed, stim_y), (stop - t0 - elapsed, stim_y)]
            for start, stop in intervals
        ]

    # Limits of the projection
    ax = _trim_axes(ax, pop, t0, tf, dims)
    if is_3d:
        _plot_xy_plane(ax, color=plt.rcParams["text.color"], alpha=0.25)

    def init():
        return (trail1,)

    def update(frames):
        # ``frames`` is the elapsed time since t0 for this frame.
        this_t = t0 + frames
        is_stim = np.any((intervals[:, 0] < this_t) & (intervals[:, 1] > this_t))

        # Trail: the last ``trail_length`` of the trajectory.
        s0, sf = np.searchsorted(pop.tbins, [this_t - trail_length, this_t])
        trail1.set_data(pop.projection[s0:sf, dims[0]], pop.projection[s0:sf, dims[1]])
        if is_3d:
            trail1.set_3d_properties(pop.projection[s0:sf, dims[2]])
        trail1.set_color(stim_color if is_stim else trail_kwargs["color"])

        # History: hide segments that start after the current time.
        history.set_color(history_colors)
        alphas = np.full(segments.shape[0], history_kwargs["alpha"], dtype=float)
        alphas[history_t > this_t] = 0
        history.set_alpha(alphas)

        # Raster window centred on the current time.
        s0, sf = np.searchsorted(pop.tbins, [this_t - winsize, this_t + winsize])
        quad.set_array(pop.raster_smoothed[:, s0:sf].ravel())

        # Aux trace window centred on the current time.
        s0, sf = np.searchsorted(aux_t, [this_t - winsize, this_t + winsize])
        dd.set_data(aux_t[s0:sf] - t0 - frames, aux[s0:sf])

        # Rotate if 3D
        if is_3d and rotate and frames > rotation_delay:
            ax.view_init(ax.elev + elev_speed, ax.azim - azim_speed)

        # Place the stimulus bars absolutely (instead of shifting them a step
        # per call), so they stay aligned with the aux trace.
        aux_stims.set_segments(_stim_segments(frames))
        return trail1, history, quad, dd, aux_stims

    # ``frames`` runs from 0 to the total duration in ``frame_step`` steps.
    ani = FuncAnimation(
        f, update, frames=np.arange(0, tf - t0, frame_step), init_func=init, blit=True
    )
    print(f"saving to {fn_out}")
    ani.save(fn_out, fps=fps, dpi=dpi)  # Performs and saves the animation.
    print("DONE!")
def make_projection(
    pop,
    t0,
    duration,
    fn_out,
    stim_color="C1",
    intervals=None,
    cvar=None,
    cvar_label="",
    cmap="magma",
    dims=[0, 1],
    frame_step=0.01,
    fps=30,
    dpi=300,
    trail_length=0.05,
    figsize=(4, 4),
    style="dark_background",
    projection_kwargs=PROJECTION_KWARGS,
    history_kwargs=HISTORY_KWARGS,
    trail_kwargs=TRAIL_KWARGS,
    baseline=20,
    mode="line",
    rotate=True,
    rotation_delay=1,
    elev_speed=0.2,
    azim_speed=0.2,
    vmin=None,
    vmax=None,
):
    """
    Make an animation of just the projection with temporal evolution.

    A baseline projection of the ``baseline`` period before ``t0`` is drawn
    first. The trajectory from ``t0`` to ``t0 + duration`` is then revealed
    frame by frame, with a highlighted trail at the current time. The
    animation is saved to ``fn_out``.

    Args:
        pop (Population): Population object containing the projection data.
        t0 (float): Start time of the projection.
        duration (float): Duration of the projection in seconds.
        fn_out (str or Path): Output file path for the animation.
        stim_color (str, optional): Color for the stimulus periods. Defaults to 'C1'.
        intervals (np.ndarray, optional): Array of intervals (n x 2) for the
            stimulus periods. Defaults to None.
        cvar (np.ndarray, optional): Covariate used to colour the history. It is
            indexed with the same sample indices as ``pop.tbins``. Defaults to None.
        cvar_label (str, optional): Label for the covariate data. Defaults to "".
        cmap (str, optional): Colormap for the cvar-coloured history. Defaults to "magma".
        dims (list, optional): Dimensions of the data to be plotted (2 or 3).
            Defaults to [0, 1].
        frame_step (float, optional): Time step between frames in seconds. Defaults to 0.01.
        fps (int, optional): Frames per second for the animation. Defaults to 30.
        dpi (int, optional): Dots per inch for the animation. Defaults to 300.
        trail_length (float, optional): Length of the trail in the projection.
            Defaults to 0.05.
        figsize (tuple, optional): Figure size for the animation. Defaults to (4, 4).
        style (str, optional): Matplotlib style to use. Defaults to "dark_background".
        projection_kwargs (dict, optional): Keyword arguments for the baseline
            projection. Not modified. Defaults to PROJECTION_KWARGS.
        history_kwargs (dict, optional): Keyword arguments for the history. Passed
            to LineCollection; must contain "alpha". Defaults to HISTORY_KWARGS.
        trail_kwargs (dict, optional): Keyword arguments for the trail. Passed to
            ``ax.plot``; must contain "color". Defaults to TRAIL_KWARGS.
        baseline (int, optional): How far before ``t0`` the baseline starts.
            0 draws a fully transparent baseline spanning 30 instead. Defaults to 20.
        mode (str, optional): "line" or "scatter" for the baseline. Defaults to "line".
        rotate (bool, optional): Whether to rotate the 3D plot. Defaults to True.
        rotation_delay (float, optional): Elapsed time after which rotation
            starts. Defaults to 1.
        elev_speed (float, optional): Elevation change per frame. Defaults to 0.2.
        azim_speed (float, optional): Azimuth change per frame. Defaults to 0.2.
        vmin (float, optional): Minimum value for the colormap normalization.
            None uses the minimum of ``cvar`` over the window. Defaults to None.
        vmax (float, optional): Maximum value for the colormap normalization.
            None uses the maximum of ``cvar`` over the window. Defaults to None.

    Returns:
        None

    Raises:
        ValueError: If ``mode`` is neither "line" nor "scatter".
    """
    plt.style.use(style)
    assert pop.projection is not None, "Projection is not yet computed"

    # Copy so that neither the caller's dict nor the module-level default is
    # modified by the alpha override / key removal below.
    projection_kwargs = dict(projection_kwargs)
    tf = t0 + duration
    is_3D = len(dims) == 3
    f = plt.figure(figsize=figsize)

    if baseline == 0:
        # Keep drawing a baseline, but make it invisible.
        baseline = 30
        projection_kwargs["alpha"] = 0

    if is_3D:
        ax = f.add_subplot(111, projection="3d")
    else:
        ax = f.add_subplot(111)

    # Baseline ================================
    if mode == "line":
        ax = pop.plot_projection_line(
            dims=dims,
            t0=t0 - baseline,
            tf=t0,
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            cvar=cvar,
            colorbar_title=cvar_label,
            **projection_kwargs,
        )
    elif mode == "scatter":
        if cvar is not None:
            # Line-only / single-colour keywords do not apply to a cvar scatter.
            projection_kwargs.pop("lw", None)
            projection_kwargs.pop("color", None)
            ax = pop.plot_projection(
                dims=dims,
                t0=t0 - baseline,
                tf=t0,
                ax=ax,
                cvar=cvar,
                vmin=vmin,
                vmax=vmax,
                colorbar_title=cvar_label,
                **projection_kwargs,
            )[1]
        else:
            ax = pop.plot_projection(
                dims=dims, t0=t0 - baseline, tf=t0, ax=ax, c=projection_kwargs["color"]
            )[1]
    else:
        raise ValueError('Mode must be "line" or "scatter"')

    # Trail ==================================
    s1, s2 = np.searchsorted(pop.tbins, [t0 - trail_length, t0])
    trail_coords = [pop.projection[s1:s2, dims[0]], pop.projection[s1:s2, dims[1]]]
    if is_3D:
        # Draw the initial trail at its true z as well (not on the z=0 plane).
        trail_coords.append(pop.projection[s1:s2, dims[2]])
    (trail1,) = ax.plot(*trail_coords, **trail_kwargs)

    # History ================================
    s0, sf = np.searchsorted(pop.tbins, [t0, tf])
    segments = np.stack(
        [pop.projection[s0 : sf - 1, dims], pop.projection[s0 + 1 : sf, dims]], axis=1
    )
    if is_3D:
        history = Line3DCollection(segments, **history_kwargs)
    else:
        history = LineCollection(segments, **history_kwargs)
    ax.add_collection(history)
    history_t = pop.tbins[s0 : sf - 1]

    if cvar is not None:
        # Colour the history by cvar. Compare against None so that an explicit
        # vmin/vmax of 0 is honoured (and matches the colorbar above).
        if vmin is None:
            vmin = np.min(cvar[s0 : sf - 1])
        if vmax is None:
            vmax = np.max(cvar[s0 : sf - 1])
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
        history.set_color(plt.get_cmap(cmap)(norm(cvar[s0 : sf - 1])))
        history_colors = None
    else:
        history.set_color("none")
        # Frame-independent colours: trail colour, stim colour inside intervals.
        history_colors = np.array(
            [trail_kwargs["color"]] * segments.shape[0], dtype="object"
        )
        if intervals is not None:
            for t1, t2 in intervals:
                history_colors[(history_t >= t1) & (history_t <= t2)] = stim_color

    # Clean up axes
    ax = _trim_axes(ax, pop, t0, tf, dims)
    # Add the XY plane in a 3D plot
    if is_3D:
        _plot_xy_plane(ax, color=plt.rcParams["text.color"], alpha=0.25)

    def init():
        return (history,)

    def update(frames):
        # ``frames`` is the elapsed time since t0 for this frame.
        this_t = t0 + frames

        # History: hide segments that start after the current time.
        if history_colors is not None:
            history.set_color(history_colors)
        alphas = np.full(segments.shape[0], history_kwargs["alpha"], dtype=float)
        alphas[history_t > this_t] = 0
        history.set_alpha(alphas)

        # Trail: the last ``trail_length`` of the trajectory.
        s1, s2 = np.searchsorted(pop.tbins, [this_t - trail_length, this_t])
        trail1.set_data(pop.projection[s1:s2, dims[0]], pop.projection[s1:s2, dims[1]])
        if is_3D:
            trail1.set_3d_properties(pop.projection[s1:s2, dims[2]])

        # Colour the trail by whether the current time is inside a stimulus.
        is_stim = intervals is not None and np.any(
            (intervals[:, 0] < this_t) & (intervals[:, 1] > this_t)
        )
        trail1.set_color(stim_color if is_stim else trail_kwargs["color"])

        # Rotate if 3D
        if is_3D and rotate and frames > rotation_delay:
            ax.view_init(ax.elev + elev_speed, ax.azim - azim_speed)
        return history, trail1

    ani = FuncAnimation(
        f, update, init_func=init, frames=np.arange(0, tf - t0, frame_step), blit=True
    )
    print(f"saving to {fn_out}")
    ani.save(fn_out, fps=fps, dpi=dpi)  # Performs and saves the animation.
    print("DONE!")
def make_rotating_projection(
    pop,
    t0,
    duration,
    fn_out,
    figsize=(4, 4),
    dims=[0, 1, 2],
    cvar=None,
    rotation_delay=1,
    elev_speed=0.1,
    azim_speed=0.1,
    mode="scatter",
    style="dark_background",
    cvar_label="",
    cmap="magma",
    vmin=None,
    vmax=None,
    frame_step=0.01,
    fps=30,
    dpi=300,
    n_frames=100,
    projection_kwargs=PROJECTION_KWARGS,
):
    """
    Rotate a 3D projection without temporal evolution.

    The projection between ``t0`` and ``t0 + duration`` is drawn once. The
    view is then rotated for ``n_frames`` frames and saved to ``fn_out``.

    Args:
        pop (Population): Population object containing the projection data.
        t0 (float): Start time of the projection.
        duration (float): Duration of the projection in seconds.
        fn_out (str or Path): Output file path for the animation.
        figsize (tuple, optional): Figure size for the animation. Defaults to (4, 4).
        dims (list, optional): Dimensions of the data to be plotted. Must have
            length 3. Defaults to [0, 1, 2].
        cvar (np.ndarray, optional): Covariate data for colouring the projection.
            Defaults to None.
        rotation_delay (float, optional): Rotation starts once the frame *index*
            exceeds this value (frames here are integer indices, not times).
            Defaults to 1.
        elev_speed (float, optional): Elevation change per frame. Defaults to 0.1.
        azim_speed (float, optional): Azimuth change per frame. Defaults to 0.1.
        mode (str, optional): 'line' or 'scatter'. Defaults to "scatter".
        style (str, optional): Matplotlib style to use. Defaults to "dark_background".
        cvar_label (str, optional): Label for the covariate data. Defaults to "".
        cmap (str, optional): Currently unused. The colormap comes from
            ``projection_kwargs`` (e.g. its "cmap" key). Defaults to "magma".
        vmin (float, optional): Minimum value for the colormap normalization.
            None uses the minimum of ``cvar`` over the window. Defaults to None.
        vmax (float, optional): Maximum value for the colormap normalization.
            None uses the maximum of ``cvar`` over the window. Defaults to None.
        frame_step (float, optional): Currently unused. Defaults to 0.01.
        fps (int, optional): Frames per second for the animation. Defaults to 30.
        dpi (int, optional): Dots per inch for the animation. Defaults to 300.
        n_frames (int, optional): Number of frames for the animation. Defaults to 100.
        projection_kwargs (dict, optional): Keyword arguments for the projection.
            Not modified. Defaults to PROJECTION_KWARGS.

    Returns:
        None

    Raises:
        ValueError: If ``mode`` is neither "line" nor "scatter".
    """
    plt.style.use(style)
    assert len(dims) == 3, "Rotating projection does not make sense without 3D"
    assert pop.projection is not None, "Projection is not yet computed"

    # Copy so that removing keys below never alters the caller's dict or the
    # module-level default shared by every call.
    projection_kwargs = dict(projection_kwargs)
    tf = t0 + duration

    f = plt.figure(figsize=figsize)
    ax = f.add_subplot(111, projection="3d")
    # Invisible dummy artist: with blit=True, update() must return artists.
    (line,) = ax.plot(0, 1, ".", alpha=0)

    s0, sf = np.searchsorted(pop.tbins, [t0, tf])
    if cvar is not None:
        # Compare against None so an explicit vmin/vmax of 0 is honoured.
        if vmin is None:
            vmin = np.min(cvar[s0 : sf - 1])
        if vmax is None:
            vmax = np.max(cvar[s0 : sf - 1])

    if mode == "line":
        # Marker size does not apply to a line plot.
        projection_kwargs.pop("s", None)
        ax = pop.plot_projection_line(
            dims=dims,
            t0=t0,
            tf=tf,
            ax=ax,
            cvar=cvar,
            vmin=vmin,
            vmax=vmax,
            colorbar_title=cvar_label,
            **projection_kwargs,
        )
    elif mode == "scatter":
        # Line-only keywords do not apply to a scatter. Keep the colour (if
        # any) for the uniformly coloured case before removing it.
        color = projection_kwargs.pop("color", None)
        projection_kwargs.pop("lw", None)
        if cvar is not None:
            ax = pop.plot_projection(
                dims=dims,
                t0=t0,
                tf=tf,
                ax=ax,
                cvar=cvar,
                colorbar_title=cvar_label,
                vmin=vmin,
                vmax=vmax,
                **projection_kwargs,
            )[1]
        else:
            ax = pop.plot_projection(dims=dims, t0=t0, tf=tf, ax=ax, c=color)[1]
    else:
        raise ValueError('Mode must be "line" or "scatter"')

    ax = _trim_axes(ax, pop, t0, tf, dims)
    _plot_xy_plane(ax, color=plt.rcParams["text.color"], alpha=0.25)

    def update(frames):
        # Rotate once past the delay (``frames`` is the integer frame index).
        if frames > rotation_delay:
            ax.view_init(ax.elev + elev_speed, ax.azim - azim_speed)
        return (line,)

    ani = FuncAnimation(f, update, frames=np.arange(n_frames), blit=True)
    print(f"saving to {fn_out}")
    ani.save(fn_out, fps=fps, dpi=dpi)  # Performs and saves the animation.
    print("DONE!")
def _trim_axes(ax, pop, t0, tf, dims):
    """Trim the limits of the projection plots.

    N.B. Might be unneeded now that the plot module does this already.

    2D axes get the data range of the window and three ticks (both ends and
    0). 3D axes get a shared cube ``[floor(min), ceil(max)]`` over all three
    dimensions. NaNs in the projection are ignored when computing limits.
    The layout is then adjusted and spines are trimmed on ``ax``'s figure.

    Args:
        ax (matplotlib.axes): Axis to modify.
        pop (Population): Population object.
        t0 (float): Start time of data that is plotted.
        tf (float): End time of data that is plotted.
        dims (list): Dimensions of the data that are plotted (2 or 3).

    Returns:
        ax (matplotlib.axes): Modified axes.
    """
    ax.autoscale()
    ax.set_aspect("equal")
    s0, sf = np.searchsorted(pop.tbins, [t0, tf])
    window = pop.projection[s0:sf]

    def _limits(dim):
        """NaN-ignoring (min, max) of one projection dimension in the window."""
        return np.nanmin(window[:, dim]), np.nanmax(window[:, dim])

    xlim = _limits(dims[0])
    ylim = _limits(dims[1])
    # Operate on the figure that owns ``ax`` rather than on pyplot's current
    # figure.
    fig = ax.figure

    if len(dims) == 2:
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        xticks = ax.get_xticks()
        yticks = ax.get_yticks()
        ax.set_xticks([xticks[0], 0, xticks[-1]])
        ax.set_yticks([yticks[0], 0, yticks[-1]])
        fig.tight_layout()
    elif len(dims) == 3:
        # Use nan-aware limits for z as for x/y. Otherwise a single NaN would
        # make zlim NaN and the z extent would be dropped from the cube below.
        zlim = _limits(dims[2])
        view_min = np.floor(np.nanmin([xlim[0], ylim[0], zlim[0]]))
        view_max = np.ceil(np.nanmax([xlim[1], ylim[1], zlim[1]]))
        ax.set_xlim([view_min, view_max])
        ax.set_ylim([view_min, view_max])
        ax.set_zlim([view_min, view_max])
        ax.set_xticks([view_min, 0, view_max])
        ax.set_yticks([view_min, 0, view_max])
        ax.set_zticks([view_min, 0, view_max])
        ax.get_proj()
        fig.subplots_adjust(left=0.2, right=0.8, top=0.9, bottom=0.1)
    sns.despine(fig=fig, trim=True)
    return ax
def _plot_xy_plane(ax, **kwargs):
    """
    Draw the z = 0 plane across the current x/y extent of a 3D axes.

    A 10 x 10 grid spanning ``ax.get_xlim()`` by ``ax.get_ylim()`` is passed
    to ``ax.plot_surface`` at height zero. The large stride values make the
    plane render as a single patch.

    Args:
        ax (Axes3D): The 3D axes on which to draw the plane.
        **kwargs: Extra keyword arguments forwarded to ``plot_surface``
            (e.g. ``color``, ``alpha``).

    Returns:
        None
    """
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_lo, x_hi, 10),
        np.linspace(y_lo, y_hi, 10),
    )
    heights = np.zeros_like(grid_x)
    ax.plot_surface(grid_x, grid_y, heights, rstride=100, cstride=100, **kwargs)
def _plot_all_planes(ax, **kwargs):
    """
    Plot transparent XY, YZ, and XZ planes on a 3D axes.

    The planes pass through the origin (z = 0, x = 0 and y = 0 respectively).
    Each spans the current limits of the axes in its two in-plane directions
    on a 10 x 10 grid. They are drawn in the order XY, YZ, XZ.

    Args:
        ax (Axes3D): The 3D axes on which to plot the planes.
        **kwargs: Additional keyword arguments to pass to the `plot_surface` method.

    Returns:
        None
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    zlim = ax.get_zlim()
    xs = np.linspace(xlim[0], xlim[1], 10)
    ys = np.linspace(ylim[0], ylim[1], 10)
    zs = np.linspace(zlim[0], zlim[1], 10)

    # Each plane gets its own grid variables. Reusing x/y/z across planes
    # previously made the XY plane use the YZ/XZ grids and collapse to a
    # diagonal strip.

    # XY plane at z = 0
    x_xy, y_xy = np.meshgrid(xs, ys)
    ax.plot_surface(
        x_xy, y_xy, np.zeros_like(x_xy), rstride=100, cstride=100, **kwargs
    )

    # YZ plane at x = 0
    y_yz, z_yz = np.meshgrid(ys, zs)
    ax.plot_surface(
        np.zeros_like(y_yz), y_yz, z_yz, rstride=100, cstride=100, **kwargs
    )

    # XZ plane at y = 0
    x_xz, z_xz = np.meshgrid(xs, zs)
    ax.plot_surface(
        x_xz, np.zeros_like(x_xz), z_xz, rstride=100, cstride=100, **kwargs
    )
if __name__ == "__main__":
    # Manual smoke test: renders a short video from a local ALF data repository.
    _log.info("Testing")
    from brainbox.io.one import SpikeSortingLoader
    from one.api import One

    from cibrrig.analysis.population import Population, get_good_spikes

    CACHE_DIR = "/data/hps/assoc/private/medullary/data/alf_data_repo"
    one = One(CACHE_DIR)
    eid = one.search(subject="m2024-40")[0]
    ssl = SpikeSortingLoader(one, eid=eid)
    spikes, clusters, channels = ssl.load_spike_sorting(spike_sorter="")
    spikes, cluster_ids = get_good_spikes(spikes, clusters)
    pop = Population(spikes.times, spikes.clusters)
    pop.compute_projection()
    physiology = one.load_object(eid, "physiology")
    dia = one.load_object(eid, "diaphragm")
    pdiff = pop.sync_var(physiology.pdiff, physiology.times)
    log = one.load_object(eid, "log").to_df()
    laser = one.load_object(eid, "laser").to_df()

    # Laser intervals that lie strictly inside the first "insp" phase of the log.
    starts, stops = log.query('phase=="insp"')[["start_time", "end_time"]].values[0]
    intervals = laser.query("intervals_0>@starts and intervals_1<@stops")[
        ["intervals_0", "intervals_1"]
    ].values
    # intervals = np.array([[10.0, 11.0], [12.0, 13.0]])
    make_aux_raster_projection_with_stims(
        pop,
        intervals,
        dia.filtered,
        dia.times,
        "test.mp4",
        aux_label="dia",
        stim_color="c",
        lead_in=2,
        duration=4,
        dpi=100,
        dims=[0, 2],
        elev_speed=0.0,
        azim_speed=0.4,
        baseline=10,
    )

    # Copy rather than alias: mutating TRAIL_KWARGS itself would change the
    # default trail style of every later call in this process.
    trail_kwargs = dict(TRAIL_KWARGS)
    trail_kwargs["alpha"] = 0
    intervals = None
    # make_projection(
    #     pop,
    #     120,
    #     4,
    #     "test_projection.mp4",
    #     stim_color="r",
    #     intervals=intervals,
    #     dpi=300,
    #     figsize=(4, 4),
    #     frame_step=0.05,
    #     mode="scatter",
    #     dims=[0, 1],
    #     cvar=pdiff,
    #     cvar_label='Pressure',
    #     baseline=0,
    #     cmap="RdBu_r",
    #     vmin=-1,
    #     vmax=1,
    #     trail_kwargs=trail_kwargs,
    # )
    # projection_kwargs = dict(PROJECTION_KWARGS)
    # projection_kwargs["alpha"] = 1
    # projection_kwargs["s"] = 1
    # make_rotating_projection(
    #     pop,
    #     120,
    #     200,
    #     "test_rotation_pdiff.mp4",
    #     cvar=pdiff,
    #     cvar_label="Pressure",
    #     projection_kwargs=projection_kwargs,
    #     n_frames=10,
    #     elev_speed=1,
    #     vmin=-1,
    #     vmax=1,
    #     mode='scatter'
    # )