import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.cm import ScalarMappable
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from one.alf.io import AlfBunch
from matplotlib.collections import LineCollection
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import MaxNLocator
try:
from brainbox.plot import driftmap
from brainbox.processing import bincount2D
import brainbox.singlecell as bbsc
except ImportError:
print("No brainbox package")
has_brainbox = False
from cibrrig.utils.utils import parse_opto_log, validate_intervals, weighted_histogram
# Maps laser wavelengths to hex codes
[docs]
laser_colors = {473: "#00b7ff", 565: "#d2ff00", 635: "#ff0000"}
[docs]
def plot_laser(laser_in, **kwargs):
"""
Flexibly overlay laser stimulation.
Args:
laser_in (AlfBunch or array-like): The laser data to be plotted. Can be an AlfBunch object or an array of intervals.
**kwargs: Additional keyword arguments to be passed to the plotting functions. These may include:
mode (str): The plotting mode. Options are "shade", "bar", "vline", or any other (defaults to steps).
ax (matplotlib.axes.Axes): The axes object to plot on. If None, a new figure and axes will be created.
amp_label (str): Label for the amplitude axis when plotting amplitudes.
wavelength (int): Laser wavelength in nm, used to determine the color of the plot.
alpha (float or array-like): The alpha (transparency) value(s) for shaded areas.
color (str or tuple): The color to use for plotting. If not provided, a default color based on wavelength is used.
query (str): Query string to filter the data (only used in _plot_laser_log).
rotation (int): Rotation angle for text annotations (only used in _plot_laser_log).
fontsize (int): Font size for text annotations (only used in _plot_laser_log).
Any other keyword arguments accepted by matplotlib plotting functions.
Returns:
matplotlib.axes.Axes: The axes object containing the plot.
Notes:
This function determines the appropriate plotting method based on the input type:
- If laser_in is an AlfBunch object with a 'category' key, it calls _plot_laser_log.
- If laser_in is an AlfBunch object without a 'category' key, it calls _plot_laser_alf.
- For other input types, it calls _plot_laser_intervals.
The specific kwargs used may vary depending on which underlying plotting function is called.
"""
if isinstance(laser_in, AlfBunch):
if "category" in laser_in.keys():
ax = _plot_laser_log(laser_in, **kwargs)
else:
ax = _plot_laser_alf(laser_in, **kwargs)
else:
ax = _plot_laser_intervals(laser_in, **kwargs)
return ax
[docs]
def _plot_laser_alf(laser_in, **kwargs):
"""
Plot laser data from a "laser" AlfBunch.
Args:
laser_in (AlfBunch): The AlfBunch object containing laser data.
**kwargs: Additional keyword arguments to be passed to _plot_laser_intervals.
Notes:
This function extracts intervals and amplitudes from the AlfBunch object and calls
_plot_laser_intervals with the appropriate parameters. It determines whether to use
milliwatts or volts for the amplitude label based on the available keys in laser_in.
"""
intervals = laser_in.intervals
if "amplitudesMilliwatts" in laser_in.keys():
amplitudes = laser_in.amplitudesMilliwatts
amp_label = "mW"
elif "amplitudesVolts" in laser_in.keys():
amplitudes = laser_in.amplitudesVolts
amp_label = "command volts"
else:
amplitudes = None
_plot_laser_intervals(intervals, amplitudes, amp_label=amp_label, **kwargs)
[docs]
def _plot_laser_intervals(
intervals,
amplitudes=None,
ax=None,
mode="shade",
amp_label="",
wavelength=473,
alpha=0.2,
**kwargs,
):
"""
Plot laser intervals from arrays.
Args:
intervals (array-like): Array of laser intervals, where each interval is [start_time, end_time].
amplitudes (array-like, optional): Array of amplitude values corresponding to each interval.
ax (matplotlib.axes.Axes, optional): The axes object to plot on. If None, a new figure and axes will be created.
mode (str, optional): The plotting mode. Options are "shade", "bar", "vline", or any other (defaults to steps).
amp_label (str, optional): Label for the amplitude axis when plotting amplitudes.
wavelength (int, optional): Laser wavelength in nm, used to determine the color of the plot. Default is 473.
alpha (float or array-like, optional): The alpha (transparency) value(s) for shaded areas. Default is 0.2.
**kwargs: Additional keyword arguments to be passed to the plotting functions.
Returns:
matplotlib.axes.Axes: The axes object containing the plot.
Notes:
This function supports multiple plotting modes:
- "shade": Shades the intervals on the plot.
- "bar": Plots horizontal bars for each interval.
- "vline": Plots vertical lines at the start of each interval.
- Any other mode defaults to plotting steps of the amplitudes.
The function handles color selection based on the wavelength and can use a list
of alpha values for varying transparency across intervals.
"""
if ax is None:
f = plt.figure()
ax = f.add_subplot(111)
try:
iter(alpha)
alpha_list = True
except Exception:
alpha_list = False
if kwargs.get("color") is not None:
color = kwargs.pop("color")
else:
color = laser_colors[wavelength]
if mode == "shade":
for ii, stim in enumerate(intervals):
aa = alpha[ii] if alpha_list else alpha
ax.axvspan(stim[0], stim[1], color=color, alpha=aa, **kwargs)
elif mode == "bar":
yy = ax.get_ylim()[1]
yy = np.ones_like(intervals[:, 0]) * yy * 0.95
ax.hlines(
yy,
intervals[:, 0],
intervals[:, 1],
color=color,
**kwargs,
)
elif mode == "vline":
y0 = ax.get_ylim()[0]
y0 = np.ones_like(intervals[:, 0]) * y0
y1 = ax.get_ylim()[1]
y1 = np.ones_like(intervals[:, 0]) * y1
ax.vlines(intervals[:, 0], y0, y1, color=color, **kwargs)
else:
# interleave zeros for the offsets
print(f"mode {mode} not found. Plotting as steps")
ax = ax.twinx()
if amplitudes is None:
new_amps = np.vstack(
[np.zeros_like(intervals[:, 0]), np.ones_like(intervals[:, 0])]
).T.ravel()
else:
new_amps = np.vstack([np.zeros_like(amplitudes), amplitudes]).T.ravel()
ax.step(intervals.ravel(), new_amps, color=laser_colors[wavelength], **kwargs)
ax.set_ylabel(amp_label)
plt.xlabel("Time (s)")
return ax
[docs]
def _plot_laser_log(log, query=None, rotation=45, fontsize=6, **kwargs):
"""
Plot laser data from a "log" AlfBunch object.
Args:
log (AlfBunch): The AlfBunch object containing laser log data.
query (str, optional): Query string to filter the data. Default is None.
rotation (int, optional): Rotation angle for text annotations. Default is 45.
fontsize (int, optional): Font size for text annotations. Default is 6.
**kwargs: Additional keyword arguments to be passed to _plot_laser_intervals.
Returns:
matplotlib.axes.Axes: The axes object containing the plot.
Notes:
This function extracts opto data from the log, plots the intervals using _plot_laser_intervals,
and adds text annotations for each interval. It handles both milliwatt and voltage amplitudes.
"""
opto_df = log.to_df().query('category=="opto"')
intervals = opto_df[["start_time", "end_time"]].values
if "amplitude_mw" in opto_df.keys():
amps = opto_df["amplitude_mw"]
amp_units = "mW"
else:
amps = opto_df["amplitude"]
amp_units = "command_volts"
ax = _plot_laser_intervals(
intervals, amplitudes=amps, amp_label=amp_units, **kwargs
)
if query:
opto_df = opto_df.query(query)
for _, rr in opto_df.iterrows():
s = parse_opto_log(rr)
# TODO: Fix text going big
ax.text(
np.mean([rr.start_time, rr.end_time]),
plt.gca().get_ylim()[1],
s,
rotation=rotation,
fontsize=fontsize,
)
return ax
[docs]
def _create_ax(dims, projection=None):
"""
Create a new figure and axis.
Args:
dims (list): Dimensions to plot.
projection (str, optional): Type of projection for 3D plots.
Returns:
tuple: Figure and Axes objects.
"""
f = plt.figure()
ax = f.add_subplot(projection=projection)
return f, ax
[docs]
def _setup_colorbar(ax, p, vmin, vmax, colorbar_title):
"""
Set up a colorbar for the plot.
Args:
ax (Axes): The axes object to add the colorbar to.
p: The plot object to create the colorbar from.
vmin (float): Minimum value for the colorbar.
vmax (float): Maximum value for the colorbar.
colorbar_title (str): Title for the colorbar.
"""
cbar = plt.colorbar(p, ax=ax, pad=0.1, orientation="horizontal", location="top")
cbar.set_label(colorbar_title)
if vmin >= 0:
cbar.set_ticks([vmin, vmax])
else:
# Check if vmin and vmax are close to -pi and pi
if np.isclose(vmin, -np.pi,rtol=0.1) and np.isclose(vmax, np.pi, rtol=0.1):
cbar.set_ticks([-np.pi, 0, np.pi])
cbar.set_ticklabels([r"$\pi$", "0", r"$\pi$"])
else:
cbar.set_ticks([vmin, 0, vmax])
cbar.outline.set_edgecolor("none")
cbar.solids.set_alpha(1)
[docs]
def plot_projection_line_multicondition(
X, tbins, intervals, colors, dims=[0, 1], ax=None, alpha=0.5, lw=1, **kwargs
):
"""
Plot low-D projection with unique coloring for the given intervals.
Args:
X (array): Data to plot.
tbins (array): Time bins.
intervals (array): Start and end times for each condition.
colors (list): Colors for each condition.
dims (list): Dimensions to plot.
ax (Axes, optional): Axes to plot on.
alpha (float): Alpha value for transparency.
lw (float): Line width.
Returns:
Axes: The axes object containing the plot.
"""
validate_intervals(intervals[:, 0], intervals[:, 1], overlap_ok=True)
assert len(colors) == intervals.shape[0]
if ax is None:
_, ax = _create_ax(dims, projection="3d" if len(dims) == 3 else None)
for (t0, tf), cc in zip(intervals, colors):
s0, sf = np.searchsorted(tbins, [t0, tf])
X_sub = X[s0 - 1 : sf, :]
ax = plot_projection_line(
X_sub, dims=dims, cvar=None, color=cc, alpha=alpha, ax=ax, lw=lw, **kwargs
)
return ax
[docs]
def plot_projection_line(X, cvar=None, dims=[0, 1], cmap="viridis", **kwargs):
"""
Plot low-d projection as a line. Optionally
Args:
X (array): Data to plot.
cvar (array, optional): Color variable.
dims (list): Dimensions to plot.
cmap (str): Colormap to use.
**kwargs: Additional keyword arguments.
Keyword Args:
ax (matplotlib.axes.Axes, optional): The axes to plot on. If not provided, a new figure and axes will be created.
color (str, optional): Color of the line if cvar is not provided. Default is "k" (black).
alpha (float, optional): The alpha blending value, between 0 (transparent) and 1 (opaque). Default is 0.5.
lw (float, optional): The line width. Default is 0.5.
vmin (float, optional): Minimum of the colormap range. If not provided, it's inferred from cvar.
vmax (float, optional): Maximum of the colormap range. If not provided, it's inferred from cvar.
colorbar_title (str, optional): Title for the colorbar. Default is an empty string.
title (str, optional): Title for the plot. Only used in 3D plots. Default is an empty string.
lims (list, optional): The x, y, (and z for 3D) limits of the plot as [min, max]. Only used in 3D plots. Default is [-4, 4].
pane_color (color, optional): Color of the panes in 3D plots. If None, default matplotlib style is used.
plot_colorbar (bool, optional): Whether to plot the colorbar. Only used in 3D plots. Default is True.
Returns:
Axes: The axes object containing the plot.
"""
if len(dims) == 2:
return _plot_projection_line_2D(X, cvar, dims, cmap=cmap, **kwargs)
elif len(dims) == 3:
ax = _plot_projection_line_3D(X, cvar, dims=dims, cmap=cmap, **kwargs)
else:
raise ValueError("Number of dims must be two or three")
return ax
[docs]
def _plot_projection_line_2D(
X,
cvar=None,
dims=[0, 1],
cmap="viridis",
color="k",
ax=None,
alpha=0.5,
vmin=None,
vmax=None,
lw=0.5,
colorbar_title="",
plot_colorbar=True,
**kwargs,
):
"""
Plot 2D projection line.
Args:
X (array): Data to plot.
cvar (array, optional): Color variable.
dims (list): Dimensions to plot.
cmap (str): Colormap to use.
color (str): Color for the line if cvar is None.
ax (Axes, optional): Axes to plot on.
alpha (float): Alpha value for transparency.
vmin (float, optional): Minimum value for colormap.
vmax (float, optional): Maximum value for colormap.
lw (float): Line width.
colorbar_title (str): Title for the colorbar.
**kwargs: Additional keyword arguments.
Returns:
Axes: The axes object containing the plot.
"""
if ax is None:
_, ax = _create_ax(dims)
segments = np.stack([X[:-1, dims], X[1:, dims]], axis=1)
use_arrow = kwargs.pop("use_arrow", None)
mutation_scale = kwargs.pop("mutation_scale", 10)
# TODO: implement an arrow in the middle of the line
multi_arrow = kwargs.pop("multi_arrow", False)
if use_arrow:
if segments.shape[0] == 0:
use_arrow = False
elif segments.shape[0] < 2:
multi_arrow = False
else:
(
a,
b,
) = segments[-1]
arrow = FancyArrowPatch(
a, b, arrowstyle="-|>", color=color, lw=lw, alpha=alpha, mutation_scale=mutation_scale
)
segments = segments[:-1]
if multi_arrow:
(a,b) = segments[1]
arrow2 = FancyArrowPatch(
a, b, arrowstyle="-|>", color=color, lw=lw, alpha=alpha, mutation_scale=mutation_scale
)
_ = kwargs.pop("s", None)
lc = LineCollection(segments, alpha=alpha, lw=lw, **kwargs)
if cvar is not None:
if vmin is None:
vmin = np.min(cvar)
if vmax is None:
vmax = np.max(cvar)
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
lc.set_array(cvar)
lc.set_cmap(cmap)
lc.set_norm(norm)
if plot_colorbar:
_setup_colorbar(ax, lc, vmin, vmax, colorbar_title)
else:
lc.set_color(color)
ax.add_collection(lc)
if use_arrow:
ax.add_patch(arrow)
if multi_arrow:
ax.add_patch(arrow2)
sns.despine()
ax.autoscale()
ax.set_aspect("equal")
ax.set_xlabel(f"Dim {dims[0]+1}")
ax.set_ylabel(f"Dim {dims[1]+1}")
# if cvar is not None and plot_colorbar:
# cbar = plt.colorbar(
# lc, ax=ax, pad=0.1, orientation="horizontal", location="top"
# )
# cbar.set_label(colorbar_title)
# cbar.set_ticks([vmin, 0, vmax])
# cbar.solids.set_alpha(1)
return ax
[docs]
def _plot_projection_line_3D(
X,
cvar=None,
dims=[0, 1, 2],
cmap="viridis",
color="k",
ax=None,
title="",
alpha=0.5,
lims=None,
pane_color=None,
colorbar_title="",
plot_colorbar=True,
vmin=None,
vmax=None,
lw=0.5,
**kwargs,
):
"""
Plot 3D projection line.
Args:
X (array): Data to plot.
cvar (array, optional): Color variable.
dims (list): Dimensions to plot.
cmap (str): Colormap to use.
color (str): Color for the line if cvar is None.
ax (Axes3D, optional): 3D axes to plot on.
title (str): Title for the plot.
alpha (float): Alpha value for transparency.
lims (list): Limits for the axes.
pane_color: Color for the panes.
colorbar_title (str): Title for the colorbar.
plot_colorbar (bool): Whether to plot the colorbar.
vmin (float, optional): Minimum value for colormap.
vmax (float, optional): Maximum value for colormap.
lw (float): Line width.
**kwargs: Additional keyword arguments.
Returns:
Axes3D: The 3D axes object containing the plot.
"""
if ax is None:
_, ax = _create_ax(dims, projection="3d")
segments = np.stack([X[:-1, dims], X[1:, dims]], axis=1)
lc = Line3DCollection(segments, alpha=alpha, lw=lw, **kwargs)
if cvar is not None:
vmin = vmin or np.min(cvar)
vmax = vmax or np.max(cvar)
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
colors = plt.get_cmap(cmap)(norm(cvar[:-1]))
lc.set_color(colors)
if plot_colorbar:
sm = ScalarMappable(cmap=cmap, norm=norm)
sm.set_array(colors)
_setup_colorbar(ax, sm, vmin, vmax, colorbar_title)
else:
lc.set_color(color)
ax.add_collection(lc)
ax.autoscale()
if lims is None:
lim = np.nanmax(np.abs(X[:, dims]))
lims = [-lim, lim]
_clean_3d_axes(ax, title, dims, pane_color, lims=lims)
return ax
[docs]
def plot_projection(X, dims, **kwargs):
"""
Plot projection in 2D or 3D.
Args:
X (array): Data to plot.
dims (list): Dimensions to plot.
**kwargs: Additional keyword arguments.
Keyword Args:
ax (matplotlib.axes.Axes, optional): The axes to plot on. If not provided, a new figure and axes will be created.
color (str, optional): The color of the data points or lines. Default is "k" (black) if no color variable (`cvar`) is provided.
alpha (float, optional): Transparency level for the points. Should be between 0 (fully transparent) and 1 (fully opaque).
- Default is 0.5.
lw (float, optional): Line width for plotting. Applies if plotting a line plot. Default is 0.5.
vmin (float, optional): Minimum value for the colormap if a color variable (`cvar`) is used. If not specified, it's inferred from `cvar`.
vmax (float, optional): Maximum value for the colormap if a color variable (`cvar`) is used. If not specified, it's inferred from `cvar`.
colorbar_title (str, optional): Title for the colorbar. Default is an empty string.
title (str, optional): Title for the plot. Only used in 3D plots. Default is an empty string.
lims (list, optional): Axis limits for the plot. Should be a list of `[min, max]` values.
- Default is [-4, 4] for 3D plots.
pane_color (str, optional): Background color of the 3D panes in 3D plots. If `None`, default style is used.
plot_colorbar (bool, optional): Whether to include a colorbar in the plot. Only applies to 3D plots when `cvar` is used. Default is True.
s (float, optional): Size of the markers used in the scatter plot. Default is 1.
cmap (str, optional): Colormap to use when plotting with a color variable (`cvar`). Default is "viridis".
cvar (array-like, optional): An array of values used to color the data points. If provided, `cmap` is applied.
s (float, optional): Size of the markers in the scatter plot. Default is 1.
Returns:
tuple: Figure and Axes objects.
"""
if len(dims) == 2:
return plot_2D_projection(X, dims, **kwargs)
elif len(dims) == 3:
return plot_3D_projection(X, dims, **kwargs)
else:
raise ValueError(f"Number of plotted dimensions must be 2 or 3. {dims=}")
[docs]
def plot_3D_projection(
X,
dims=[0, 1, 2],
cvar=None,
ax=None,
title="",
s=1,
vmin=None,
vmax=None,
cmap="viridis",
c="k",
alpha=0.2,
lims=None,
plot_colorbar=True,
colorbar_title="",
pane_color=None,
**kwargs,
):
"""
Plot 3D projection.
Args:
X (array): Data to plot.
dims (list): Dimensions to plot.
cvar (array, optional): Color variable.
ax (Axes3D, optional): 3D axes to plot on.
title (str): Title for the plot.
s (float): Size of the markers.
vmin (float, optional): Minimum value for colormap.
vmax (float, optional): Maximum value for colormap.
cmap (str): Colormap to use.
c (str): Color for the markers if cvar is None.
alpha (float): Alpha value for transparency.
lims (list): Limits for the axes.
plot_colorbar (bool): Whether to plot the colorbar.
colorbar_title (str): Title for the colorbar.
pane_color: Color for the panes.
**kwargs: Additional keyword arguments.
Returns:
tuple: Figure and Axes3D objects.
"""
assert len(dims) == 3, f"Must choose 3 dimensions to plot. Chose {dims}"
if ax is None:
_, ax = _create_ax(dims, projection="3d")
if cvar is None:
p = ax.scatter(
X[:, dims[0]],
X[:, dims[1]],
X[:, dims[2]],
c=cvar,
s=s,
alpha=alpha,
cmap=cmap,
vmin=vmin,
vmax=vmax,
**kwargs,
)
else:
p = ax.scatter(
X[:, dims[0]],
X[:, dims[1]],
X[:, dims[2]],
c=c,
s=s,
alpha=alpha,
cmap=None,
vmin=vmin,
vmax=vmax,
**kwargs,
)
if cvar is not None and plot_colorbar:
_setup_colorbar(
ax, p, vmin or np.min(cvar), vmax or np.max(cvar), colorbar_title
)
ax.autoscale()
if lims is None:
lim = np.nanmax(np.abs(X[:, dims]))
lims = [-lim, lim]
_clean_3d_axes(ax, title, dims, pane_color, lims=lims)
return ax
[docs]
def plot_2D_projection(
X,
dims=[0, 1],
cvar=None,
ax=None,
title="",
s=1,
vmin=None,
vmax=None,
cmap="viridis",
c="C1",
alpha=0.2,
lims=[-4, 4],
plot_colorbar=True,
colorbar_title="",
):
"""
Plot 2D projection.
Args:
X (array): Data to plot.
dims (list): Dimensions to plot.
cvar (array, optional): Color variable.
ax (Axes, optional): Axes to plot on.
title (str): Title for the plot.
s (float): Size of the markers.
vmin (float, optional): Minimum value for colormap.
vmax (float, optional): Maximum value for colormap.
cmap (str): Colormap to use.
c (str): Color for the markers if cvar is None.
alpha (float): Alpha value for transparency.
lims (list): Limits for the axes.
plot_colorbar (bool): Whether to plot the colorbar.
colorbar_title (str): Title for the colorbar.
Returns:
tuple: Figure and Axes objects.
"""
assert len(dims) == 2, f"Must choose 2 dimensions to plot. Chose {dims}"
if ax is None:
_, ax = _create_ax(dims)
if cvar is None:
p = ax.scatter(
X[:, dims[0]],
X[:, dims[1]],
c=c,
s=s,
alpha=alpha,
cmap=None,
vmin=vmin,
vmax=vmax,
)
else:
p = ax.scatter(
X[:, dims[0]],
X[:, dims[1]],
c=cvar,
s=s,
alpha=alpha,
cmap=cmap,
vmin=vmin,
vmax=vmax,
)
if cvar is not None and plot_colorbar:
_setup_colorbar(
ax, p, vmin or np.min(cvar), vmax or np.max(cvar), colorbar_title
)
ax.set_title(title)
ax.autoscale()
ax.set_aspect("equal")
ax.set_xlabel(f"Dim {dims[0]+1}")
ax.set_ylabel(f"Dim {dims[1]+1}")
ax.spines[["right", "top"]].set_visible(False)
return ax.get_figure(), ax
[docs]
def plot_polar_average(
x,
y,
t,
ax=None,
t0=None,
tf=None,
color="k",
bins=50,
multi="sem",
alpha=0.3,
**plot_kwargs,
):
"""
Plot covariate `y` as a function of phase `x` on a polar.
If t0,tf are arrays, will average over multiple intervals
Args:
x (1D numpy array): Phase data with values in the range [-pi, pi].
y (1D numpy array): Signal data to be plotted against `x`.
t (1D numpy array): Time data corresponding to `x` and `y`.
ax (matplotlib.axes.Axes, optional): The axes object to plot on. If None, a new figure and axes are created. Defaults to None.
t0 (int, float, or list, optional): Start time(s) for epoch selection. If a list or array, averages over multiple epochs are computed. Defaults to None.
tf (int, float, or list, optional): End time(s) for epoch selection, matching the format of `t0`. Defaults to None.
color (str or list, optional): Line color(s) for the plot. Defaults to 'k'.
bins (int, optional): Number of bins for the polar histogram. Defaults to 50.
multi (str, optional): Specifies the method for calculating the shaded region. Options are 'std' for standard deviation or 'sem' for standard error of the mean. Defaults to 'sem'.
alpha (float, optional): Transparency of the shaded region. Defaults to 0.3.
**plot_kwargs: Additional keyword arguments passed to `ax.plot`.
Returns:
tuple: A tuple containing:
- `f`: The created figure object (or None if `ax` was provided).
- `ax`: The axes object used for plotting.
- `y_polar_out`: 2D numpy array of the polar data averaged over epochs.
- `phase_bins`: Phase bin centers for the plot.
Example:
>>> plot_polar_average(x, y, t, t0=0, tf=10, color='b', bins=30, multi='std')
"""
try:
iter(t0)
except Exception:
t0 = [t0]
try:
iter(tf)
except Exception:
tf = [tf]
if type(color) is not list:
color = [color]
assert len(t0) == len(tf), f"{len(t0)=} and {len(tf)=}; they must have same shape"
y_polar_out = []
for ii, (start, stop) in enumerate(zip(t0, tf)):
s0, sf = np.searchsorted(t, [start, stop])
phase_bins, y_polar = weighted_histogram(
x[s0:sf], y[s0:sf], bins=bins, wrap=True
)
y_polar_out.append(y_polar)
y_polar_out = np.vstack(y_polar_out)
m = np.mean(y_polar_out, 0)
# Plotting
if ax is None:
f = plt.figure()
ax = f.add_subplot(projection="polar")
else:
f = None
if multi == "sem":
lb = m - np.nanstd(y_polar_out, 0) / np.sqrt(y_polar_out.shape[0])
ub = m + np.nanstd(y_polar_out, 0) / np.sqrt(y_polar_out.shape[0])
ax.plot(phase_bins, np.mean(y_polar_out, 0), color=color[0], **plot_kwargs)
ax.fill_between(phase_bins, lb, ub, color=color[0], alpha=alpha)
elif multi == "std":
lb = m - np.nanstd(y_polar_out, 0)
ub = m + np.nanstd(y_polar_out, 0)
ax.plot(phase_bins, np.mean(y_polar_out, 0), color=color[0], **plot_kwargs)
ax.fill_between(phase_bins, lb, ub, color=color[0], alpha=alpha)
else:
for ii, y_polar in enumerate(y_polar_out):
if len(color) == 1:
c = color[0]
else:
c = color[ii]
ax.plot(phase_bins, y_polar, color=c, **plot_kwargs)
clean_polar_axis(ax)
return (f, ax, y_polar_out, phase_bins)
[docs]
def plot_reset_curve(
breaths,
events,
wavelength=473,
annotate=False,
norm=True,
plot_tgl=True,
n_control=100,
):
"""
Plot a reset curve for optogenetic stimulation, showing phase-dependent effects on breathing cycles.
Args:
breaths (AlfBunch): Breath timing data with attributes: 'times', 'IBI', and 'duration_sec'.
events (np.ndarray): 1D array of stimulation/event times.
wavelength (int, optional): Wavelength of optogenetic stimulus. Defaults to 473.
annotate (bool, optional): If True, add annotations and color overlays to the plot. Defaults to False.
norm (bool, optional): If True, normalizes time to phase (0-1) for plotting. Defaults to True.
plot_tgl (bool, optional): If True, creates a plot; if False, returns computed data. Defaults to True.
n_control (int, optional): Number of random control points for a control distribution. Defaults to 100.
Returns:
tuple:
- `cycle_stim_time`: Normalized/raw times of stimulation relative to breath onset.
- `cycle_duration`: Normalized/raw breath cycle durations following stimulation.
- `cycle_stim_time_rand`: Control times for stimulation from random event times.
- `cycle_duration_rand`: Control breath cycle durations for random events.
"""
def _get_relative_times(times, events):
idx_last = np.searchsorted(times, events) - 1
idx_next = idx_last + 1
return events - times[idx_last], times[idx_next] - events
# Filter breaths within the event range
t0, tf = events.min(), events.max()
valid_breaths = (breaths.times > t0) & (breaths.times < tf)
mean_IBI, mean_dur = (
breaths.IBI[valid_breaths].mean(),
breaths.duration_sec[valid_breaths].mean(),
)
norm_value = mean_IBI if norm else 1
# Calculate random control data
rand_samp = np.random.uniform(low=t0, high=tf, size=n_control)
t_last_rand, t_next_rand = _get_relative_times(breaths.times, rand_samp)
cycle_duration_rand = (t_next_rand + t_last_rand) / norm_value
cycle_stim_time_rand = t_last_rand / norm_value
# Calculate stimulation event data
t_last, t_next = _get_relative_times(breaths.times, events)
cycle_duration = (t_next + t_last) / norm_value
cycle_stim_time = t_last / norm_value
# Plot control data
if plot_tgl:
plt.plot(
cycle_stim_time_rand, cycle_duration_rand, "ko", ms=3, alpha=0.5, mew=0
)
# Plot stimulation event data
if plot_tgl:
plt.plot(
cycle_stim_time,
cycle_duration,
"o",
color=laser_colors[wavelength],
mec="k",
mew=0,
)
# Return computed data if plot is disabled
if not plot_tgl:
return (
cycle_stim_time,
cycle_duration,
cycle_stim_time_rand,
cycle_duration_rand,
)
# Plot aesthetics
def _prettify_plot(norm, mean_dur, mean_IBI):
if norm:
plt.axvline(mean_dur / mean_IBI, color="k", ls="--", lw=0.5)
plt.axhline(1, color="k", ls="--", lw=0.5)
plt.plot([0, 2], [0, 2], color="tab:red")
plt.xlabel("Stim time (normalized)")
plt.ylabel("Cycle duration (normalized)")
plt.xlim(0, 1.5)
plt.ylim(0, 2)
plt.xticks([0, 0.5, 1])
plt.yticks([0, 1, 2])
else:
xmax = np.max(np.concatenate([t_last, t_last_rand]))
ymax = np.max(np.concatenate([t_next, t_next_rand]))
plt.axvline(mean_dur, color="k", ls="--", lw=0.5)
plt.axhline(mean_IBI, color="k", ls="--", lw=0.5)
plt.plot(
[0, mean_dur + mean_IBI], [0, mean_IBI + mean_dur], color="tab:red"
)
plt.xlabel("Time since last breath onset (s)")
plt.ylabel("Total time between breaths (s)")
plt.xlim([0, xmax])
plt.ylim([0, ymax * 1.1])
_prettify_plot(norm, mean_dur, mean_IBI)
# Add annotations and overlays if requested
if annotate:
def _add_annotations(mean_dur, mean_IBI):
plt.text(
0.01, 1.5, "Prolong inspiration", ha="left", va="bottom", rotation=90
)
plt.text(
0.01, 0.01, "Shorten inspiration", ha="left", va="bottom", rotation=90
)
plt.text(
mean_dur / mean_IBI + 0.01,
mean_dur / mean_IBI + 0.05,
"Phase advance",
rotation=90,
)
plt.text(mean_dur / mean_IBI + 0.01, 1.5, "Phase delay", rotation=90)
plt.fill_between(
[0, mean_dur / mean_IBI],
[0, mean_dur / mean_IBI],
[1, 1],
color="tab:purple",
alpha=0.2,
)
plt.fill_between(
[0, mean_dur / mean_IBI], [1, 1], [2, 2], color="tab:green", alpha=0.2
)
pts = np.array(
[
[mean_dur / mean_IBI, 1],
[1, 1],
[1.5, 1.5],
[1.5, 2],
[mean_dur / mean_IBI, 2],
]
)
plt.fill(pts[:, 0], pts[:, 1], color="tab:orange", alpha=0.2)
plt.fill_between(
[mean_dur / mean_IBI, 1],
[mean_dur / mean_IBI, 1],
[1, 1],
color="tab:grey",
alpha=0.2,
)
plt.text(
mean_dur / mean_IBI / 2,
mean_dur / mean_IBI / 2 * 0.8,
"Lower bound",
color="tab:red",
rotation=26,
)
plt.text(
mean_dur / mean_IBI / 2,
plt.gca().get_ylim()[1],
"Inspiration",
ha="center",
va="top",
)
plt.text(
mean_dur / mean_IBI + (1 - mean_dur / mean_IBI) / 2,
plt.gca().get_ylim()[1],
"Expiration",
ha="center",
va="top",
)
_add_annotations(mean_dur, mean_IBI)
sns.despine()
return cycle_stim_time, cycle_duration, cycle_stim_time_rand, cycle_duration_rand
[docs]
def plot_sweeps(xt, x, times, pre, post, ax=None, **kwargs):
"""
Time-aligns a trace `x` to event times specified in `times`.
Args:
xt (array-like): Time values corresponding to the signal trace `x`.
x (array-like): Signal trace data to be plotted.
times (array-like): Event times to align the trace `x` to.
pre (float): Time before each event to start the trace.
post (float): Time after each event to end the trace.
ax (matplotlib.axes.Axes, optional): Axes to plot on. If None, a new figure and axes will be created. Default is None.
**kwargs: Additional keyword arguments passed to `ax.plot`, such as line style or color.
Returns:
ax: matplotlib axis
"""
if ax is None:
f = plt.figure()
ax = f.add_subplot()
for tt in times:
t0 = tt - pre
tf = tt + post
s0, st, sf = np.searchsorted(xt, [t0, tt, tf])-1
ax.plot(xt[s0:sf] - xt[st], x[s0:sf], **kwargs)
return ax
[docs]
def plot_most_likely_dynamics(
model,
xlim=(-4, 4),
ylim=(-3, 3),
nxpts=20,
nypts=20,
alpha=0.8,
ax=None,
figsize=(3, 3),
colors=[f"C{x}" for x in range(7)],
zval=None
):
"""
Plotting of underlying vector fields from Linderman Lab
"""
assert model.D == 2
x = np.linspace(*xlim, nxpts)
y = np.linspace(*ylim, nypts)
X, Y = np.meshgrid(x, y)
xy = np.column_stack((X.ravel(), Y.ravel()))
# Get the probability of each state at each xy location
z = np.argmax(xy.dot(model.transitions.Rs.T) + model.transitions.r, axis=1)
if ax is None:
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
dxydt_m = xy.dot(A.T) + b - xy
zk = z == k
if zk.sum(0) > 0:
ax.quiver(
xy[zk, 0],
xy[zk, 1],
dxydt_m[zk, 0],
dxydt_m[zk, 1],
color=colors[k % len(colors)],
alpha=alpha,
)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
return ax
[docs]
def plot_most_likely_dynamics_3D(
model,
xlim=(-4, 4),
ylim=(-3, 3),
zlim=(-3, 3),
nxpts=10,
nypts=10,
nzpts=10,
alpha=0.2,
ax=None,
figsize=(3, 3),
length=0.2,
colors=[f"C{x}" for x in range(7)],
):
"""
Extension of the linderman vectorfield plot to 3D
"""
assert model.D == 3
x = np.linspace(*xlim, nxpts)
y = np.linspace(*ylim, nypts)
z = np.linspace(*zlim, nzpts)
X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
xyz = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))
# Get the probability of each state at each xyz location
k_state = np.argmax(xyz.dot(model.transitions.Rs.T) + model.transitions.r, axis=1)
if ax is None:
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111, projection="3d")
for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
dxyzdt_m = xyz.dot(A.T) + b - xyz
zk = k_state == k
if zk.sum(0) > 0:
ax.quiver(
xyz[zk, 0],
xyz[zk, 1],
xyz[zk, 2],
dxyzdt_m[zk, 0],
dxyzdt_m[zk, 1],
dxyzdt_m[zk, 2],
color=colors[k % len(colors)],
alpha=alpha,
length=length,
)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")
ax.grid(visible=False)
plt.tight_layout()
return ax
[docs]
def _clean_3d_axes(ax, title, dims, pane_color, lims=None):
"""Modify 3D axes to be cleaner:
Set title
set axis labels
make limits equal
turn off grid
set background color
Args:
ax (matplotlib.axes._subplots.Axes3DSubplot): The 3D axes object to customize.
title (str): The title of the plot.
dims (tuple or list of ints): Dimensions to label the axes, corresponding to the 3D data dimensions (e.g., (0, 1, 2) for first three components).
pane_color (tuple or None): RGB color to set for the panes (background of each axis). Use None for default color.
lims (tuple or list of floats,optional): If None, autoscales axes. Axis limits to set for x, y, and z axes (e.g., (-1, 1) to set limits for all axes).
Returns:
ax (matplotlib.axes._subplots.Axes3DSubplot): The modified axes object.
Example:
ax = fig.add_subplot(111, projection='3d')
_clean_3d_axes(ax, "3D Plot", (0, 1, 2), (0.9, 0.9, 0.9, 0.5), (-1, 1))
"""
ax.set_title(title)
if lims is not None:
ax.set_xlim(lims)
ax.set_ylim(lims)
ax.set_zlim(lims)
ax.set_xlabel(f"Dim {dims[0]+1}")
ax.set_ylabel(f"Dim {dims[1]+1}")
ax.set_zlabel(f"Dim {dims[2]+1}")
ax.grid(False)
if pane_color is not None:
ax.xaxis.set_pane_color(pane_color) # Set the color of the x-axis pane
ax.yaxis.set_pane_color(pane_color) # Set the color of the y-axis pane
ax.zaxis.set_pane_color(pane_color) # Set the color of the z-axis pane
return ax
[docs]
def clean_polar_axis(ax):
"""
Clean the appearance of a polar plot.
Use pi/2 (90 degrees) angular ticks, no internal radial ticks, and set labels to mathtext pi
Args:
ax (matplotlib.projections.polar.PolarAxes): The polar axes object to modify.
Example:
ax = plt.subplot(projection='polar')
clean_polar_axis(ax)
"""
ax.set_yticks([ax.get_yticks()[-1]])
ax.set_xticks([0, np.pi / 2, np.pi, np.pi * 3 / 2])
ax.set_xticklabels(["0", "$\\frac{\pi}{2}$", "$\pi$", "$\\frac{-\pi}{2}$"])
[docs]
def clean_linear_radial_axis(ax):
"""
Clean the appearance of a plot with a range o [-pi,pi] but on a normal, linear axis
Sets ticks to every pi/2 interval and uses math text.
Args:
ax (matplotlib.axes._subplots.AxesSubplot): The axes object to modify.
Example:
ax = plt.subplot()
clean_linear_radial_axis(ax)
"""
ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
ax.set_xticklabels(
["$-\pi$", "$\\frac{-\pi}{2}$", "0", "$\\frac{\pi}{2}$", "$\pi$"]
)
sns.despine(trim=True)
[docs]
def plot_driftmap_with_trace(
spike_times,
spike_depths,
trace,
trace_times,
trace_label="",
t0=None,
tf=None,
depth_lim=(None, None),
trace_ylim=(None, None),
t_bin=0.01,
driftmap_kwargs={},
trace_kwargs={},
figsize=(2, 8),
use_scalebar=True,
use_colorbar=True,
cmap=None,
raster_ylabel=None,
):
"""
#TODO: Update documentation
#TODO: UPdate to work more intuitively with clusters
Works well as a driftmap, but lass good as a rastermap
Plot a drift map with an covariate trace above.
Built off of ibllib.brainbox.plot.driftmap.
This function plots a drift map of spike times and depths, with an overlaid trace such as diaphragm or another continuous signal.
Args:
spike_times (array-like): Array of spike times.
spike_depths (array-like): Array of spike depths corresponding to each spike time.
trace (array-like): Continuous signal to overlay on the drift map. Can be multiple columns as long as each row is a timepoint
trace_times (array-like): Time points corresponding to the trace signal.
trace_label (str, optional): Label for the trace. Defaults to ''.
t0 (float, optional): Start time for the plot. Defaults to None, which uses the minimum spike time.
tf (float, optional): End time for the plot. Defaults to None, which uses the maximum spike time.
depth_lim (tuple, optional): Depth limits for the plot. Defaults to (None, None), which uses the min and max spike depths.
trace_ylim (tuple, optional): Y-axis limits for the trace plot. Defaults to (None, None), which uses the min and max trace values.
driftmap_kwargs (dict, optional): Additional keyword arguments for the drift map plot. Defaults to {}.
trace_kwargs (dict, optional): Additional keyword arguments for the trace plot. Defaults to {}.
figsize (tuple, optional): Figure size for the plot. Defaults to (2, 8).
use_scalebar (bool, optional): Whether to use a scale bar in the plot. Defaults to True.
Returns:
matplotlib.axes._subplots.AxesSubplot: The axes object containing the drift map.
matplotlib.axes._subplots.AxesSubplot: The axes object containing the trace plot.
"""
assert len(depth_lim) == 2, "depth_lim must be of length 2"
assert len(trace_ylim) == 2, "trace_ylim must be of length 2"
# set default trace kwargs and overwrite with user input
trace_kwargs_default = {"lw": 0.5}
trace_kwargs_default.update(trace_kwargs)
trace_kwargs = trace_kwargs_default
# set default driftmap kwargs and overwrite with user input
driftmap_kwargs_default = {}
driftmap_kwargs_default.update(driftmap_kwargs)
driftmap_kwargs = driftmap_kwargs_default
# Set time limits for spikes
t0 = t0 or 0
tf = tf or np.nanmax(spike_times)
s0, sf = np.searchsorted(spike_times, [t0, tf])
# Set depth limits
depth_lim = list(depth_lim)
depth_lim[0] = depth_lim[0] or np.min(spike_depths)
depth_lim[1] = depth_lim[1] or np.max(spike_depths)
# Subsample trace
trace_samples = np.logical_and(trace_times >= t0, trace_times <= tf)
trace_subset = trace[trace_samples]
trace_times_subset = trace_times[trace_samples]
trace_ylim = list(trace_ylim)
trace_ylim[0] = trace_ylim[0] or np.min(trace_subset)
trace_ylim[1] = trace_ylim[1] or np.max(trace_subset)
# Make plot
f = plt.figure(figsize=figsize)
gs = f.add_gridspec(nrows=2, ncols=1, height_ratios=[1, 5])
ax_raster = f.add_subplot(gs[1])
ax_trace = f.add_subplot(gs[0], sharex=ax_raster)
ax_trace.plot(trace_times_subset, trace_subset, **trace_kwargs)
driftmap(
spike_times[s0:sf],
spike_depths[s0:sf],
ax=ax_raster,
t_bin=t_bin,
**driftmap_kwargs,
)
ax_raster.set_ylim(depth_lim)
ax_trace.set_ylim(trace_ylim)
ax_trace.set_ylabel(trace_label)
print(ax_raster.get_ylim())
sns.despine()
if cmap is not None:
ax_raster.images[0].set_cmap(cmap)
# Add colorbar
if use_colorbar:
cbar_ax = f.add_axes([0.92, 0.25, 0.05, 0.25]) # [left, bottom, width, height]
cbar = f.colorbar(ax_raster.images[0], cax=cbar_ax, orientation="vertical")
cbar_ax.set_ylabel("spike rate (sp/s)")
cbar.outline.set_linewidth(0)
cbar.ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x/t_bin:0.0f}"))
cbar.ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
# replace x spine with a horizontal bar for for raster. Horizontal bar should be closeste to [1,2,5,10,30,60]
if use_scalebar:
replace_timeaxis_with_scalebar(ax_raster)
plt.subplots_adjust(hspace=figsize[1] * 0.01)
if raster_ylabel is not None:
ax_raster.set_ylabel(raster_ylabel)
return (ax_raster, ax_trace)
[docs]
def replace_timeaxis_with_scalebar(ax, pad=0.025,lw=None,color=None,size=None):
"""
Replace the x-axis with a horizontal bar showing the time scale of the plot.
Args:
ax (matplotlib.axes._subplots.AxesSubplot): The axes object to modify.
inverted_y (bool, optional): Set true if the plot has 0 at the top
pad (float, optional): Padding between the scale bar and the plot. Defaults to 0.01.
lw (float, optional): Line width for the scale bar. Defaults to None, which uses the default line width.
color (str, optional): Color for the scale bar. Defaults to None, which uses the default text color.
size (float, optional): Font size for the scale bar. Defaults to None, which uses the default tick size.
"""
if lw is None:
lw = plt.rcParams['axes.linewidth']
if color is None:
color = plt.rcParams['text.color']
if size is None:
size = plt.rcParams['xtick.labelsize']
t0, tf = ax.get_xlim()
good_tbars = [0.001, 0.01, 0.05, 0.1, 0.5, 1, 2, 5, 10, 30, 60]
idx = np.searchsorted(good_tbars, (tf - t0) / 5) - 1
tbar_max = t0 + good_tbars[idx]
tbar_length = tbar_max - t0
tbar_label = (
f"{tbar_length*1000:0.0f}ms" if tbar_length < 1 else f"{tbar_length:.0f}s"
)
ymin, ymax = ax.get_ylim()
yrange = ymax - ymin
pad_small = pad * yrange
pad_large = pad_small * 1.1
ax.set_ylim([ymin - pad_large, ymax])
ax.hlines(ymin - pad_small, t0, tbar_max, lw=lw, color=color)
ax.text(
t0,
ymin - pad_large,
tbar_label,
va="top",
ha="left",
color=color,
size=size,
)
ax.set_xticks([])
ax.set_xlabel("")
sns.despine(bottom=True)
[docs]
def trim_yscale_to_lims(ax, ymin, ymax):
ax.set_yticks([ymin, ymax])
sns.despine(bottom=True, trim=True)
if has_brainbox:
[docs]
def plot_peth_and_raster(
spike_times,
starts,
stops=None,
pre_time=0.2,
post_time=0.2,
bin_size=0.01,
smoothing=0,
error_bars="sem",
pethline_kwargs={},
errbar_kwargs={"alpha": 0.5},
eventline_kwargs={"color": plt.rcParams["text.color"], "ls": "--"},
raster_kwargs={"s": 2, "marker": "|"},
raster_ylabel="",
figsize=(3, 6),
subplot_ratio=(1, 5),
):
"""
Plot a peri-event time histogram and raster plot
Args:
spike_times (array): spike times
starts (array): event times
stops (array, optional): stop times. Defaults to None.
pre_time (float, optional): time before event to plot. Defaults to 0.2.
post_time (float, optional): time after event to plot. Defaults to 0.2.
bin_size (float, optional): bin size for histogram. Defaults to 0.01.
smoothing (float, optional): smoothing factor for histogram. Defaults to 0.
error_bars (str, optional): error bars to plot. Defaults to "sem".
pethline_kwargs (dict, optional): kwargs for the peth line. Defaults to {}.
errbar_kwargs (dict, optional): kwargs for the error bars. Defaults to {"alpha": 0.5}.
eventline_kwargs (dict, optional): kwargs for the event line. Defaults to {"color": plt.rcParams["text.color"], "ls": "--"}.
raster_kwargs (dict, optional): kwargs for the raster plot. Defaults to {'s':2,'marker':'|'}.
raster_ylabel (str, optional): ylabel for the raster plot. Defaults to "".
figsize (tuple, optional): figure size. Defaults to (3, 6).
subplot_ratio (tuple, optional): ratio of the subplots. Defaults to (1, 5).
Returns:
ax_raster, ax_peth: matplotlib axis for the raster and peth plots
"""
peth, _ = bbsc.calculate_peths(
spike_times,
np.ones_like(spike_times),
[1],
starts,
pre_time=pre_time,
post_time=post_time,
bin_size=bin_size,
smoothing=smoothing,
)
# Compute spikes separately for raster for better res
peth_high_res, spikes = bbsc.calculate_peths(
spike_times,
np.ones_like(spike_times),
[1],
starts,
pre_time=pre_time,
post_time=post_time,
bin_size=0.001,
smoothing=smoothing,
)
duration_mean = None
if stops is not None:
assert len(starts) == len(stops)
duration_mean = np.mean(stops - starts)
use_error = True
if error_bars == "sem":
sem = peth["stds"][0] / np.sqrt(len(starts))
lb = peth["means"][0] - sem
ub = peth["means"][0] + sem
elif error_bars == "std":
lb = peth["means"][0] - peth["stds"][0]
ub = peth["means"][0] + peth["stds"][0]
else:
use_error = False
f = plt.figure(figsize=figsize)
gs = f.add_gridspec(nrows=2, height_ratios=subplot_ratio)
ax_raster = f.add_subplot(gs[1:, :])
ax_peth = f.add_subplot(gs[0, :], sharex=ax_raster)
ax_peth.plot(peth["tscale"], peth["means"][0], **pethline_kwargs)
if use_error:
ax_peth.fill_between(peth["tscale"], lb, ub, **errbar_kwargs)
cc = peth_high_res["tscale"][np.where(spikes)[2]]
rr = np.where(spikes)[0]
ax_raster.scatter(cc, rr, **raster_kwargs)
ax_peth.set_ylabel("FR (sp/s)")
ax_peth.set_ylim([0, None])
ax_raster.set_ylabel(raster_ylabel)
ax_raster.set_xlabel("Time (s)")
ax_raster.set_ylim([0, len(starts)])
ax_peth.set_xlim([-pre_time, post_time])
ax_raster.axvline(0, **eventline_kwargs)
ax_peth.axvline(0, **eventline_kwargs)
if duration_mean is not None:
ax_peth.axvline(
duration_mean, color=plt.rcParams["text.color"], linestyle="--"
)
ax_raster.axvline(
duration_mean, color=plt.rcParams["text.color"], linestyle="--"
)
plt.tight_layout()
return (ax_raster, ax_peth)
def plot_multicondition_peth(
spike_times,
event_times,
conditions,
pre_time=0.2,
post_time=0.2,
bin_size=0.01,
smoothing=0,
error_bars="sem",
pethline_kwargs={},
errbar_kwargs={"alpha": 0.5},
eventline_kwargs={"color": plt.rcParams["text.color"], "ls": "--"},
ax=None,
colors=None,
):
"""
Plot a peri-event time histogram for multiple conditions
Args:
spike_times (array): spike times
event_times (array): event times
conditions (array): condition labels
pre_time (float, optional): time before event to plot. Defaults to 0.2.
post_time (float, optional): time after event to plot. Defaults to 0.2.
bin_size (float, optional): bin size for histogram. Defaults to 0.01.
smoothing (float, optional): smoothing factor for histogram. Defaults to 0.
error_bars (str, optional): error bars to plot. Defaults to "sem".
pethline_kwargs (dict, optional): kwargs for the peth line. Defaults to {}.
errbar_kwargs (dict, optional): kwargs for the error bars. Defaults to {"alpha": 0.5}.
eventline_kwargs (dict, optional): kwargs for the event line. Defaults to {"color": plt.rcParams["text.color"], "ls": "--"}.
ax (matplotlib axis, optional): axis to plot on. Defaults to None.
colors (list or dict, optional): colors for each condition. Defaults to None.
Returns:
matplotlib axis: axis with the plot
"""
if ax is None:
f, ax = plt.subplots(1, 1)
unique_conditions = list(set(conditions))
# Assign colors
if colors is not None:
if type(colors) is list:
assert len(colors) == len(
set(conditions)
), "Number of colors must match number of conditions"
color_map = {c: col for c, col in zip(unique_conditions, colors)}
elif type(colors) is dict:
color_map = colors
else:
raise ValueError("Colors must be a list or dict")
else:
color_map = {c: f"C{i}" for i, c in enumerate(unique_conditions)}
# Remove color from kwargs if present
pethline_kwargs.pop("color", None)
errbar_kwargs.pop("color", None)
assert (
len(conditions) == len(event_times)
), f"Number of conditions {len(conditions)} must match number of event times {len(event_times)}"
lines = []
for condition in unique_conditions:
idx = np.where(conditions == condition)[0]
peth, _ = bbsc.calculate_peths(
spike_times,
np.ones_like(spike_times),
[1],
event_times[idx],
pre_time=pre_time,
post_time=post_time,
bin_size=bin_size,
smoothing=smoothing,
)
if error_bars == "sem":
sem = peth["stds"][0] / np.sqrt(len(event_times[idx]))
lb = peth["means"][0] - sem
ub = peth["means"][0] + sem
elif error_bars == "std":
lb = peth["means"][0] - peth["stds"][0]
ub = peth["means"][0] + peth["stds"][0]
ll = ax.plot(
peth["tscale"],
peth["means"][0],
label=condition,
color=color_map[condition],
**pethline_kwargs,
)
lines.append(ll[0])
ax.fill_between(
peth["tscale"], lb, ub, color=color_map[condition], **errbar_kwargs
)
ax.axvline(0, **eventline_kwargs)
ax.set_ylabel("FR (sp/s)")
ax.set_xlabel("Time (s)")
ax.legend(lines, unique_conditions)
ax.set_xlim([-pre_time, post_time])
ax.set_ylim([0, None])
return ax