"""
I/O functions with NIDAQ data specifically for some of our physiology needs.
This module is primarily focused on loading and manipulating data from
the NIDAQ files recorded by spikeglx, and passes most computation to the physiology module,
which is more general in scope.
"""
import scipy.integrate
import numpy as np
import scipy.signal as sig
from scipy.ndimage import median_filter
import warnings
import spikeglx
try:
from cibrrig.preprocess import physiology
except ImportError:
import sys
sys.path.append("../")
import physiology
import logging
import re
import matplotlib.pyplot as plt
# Give the root logger a default handler so this module's log records are
# shown; basicConfig() does nothing if the root logger is already configured.
logging.basicConfig()
[docs]
# Module-level logger used for the progress messages emitted below;
# its threshold is set to INFO so those messages are not filtered out here.
_log = logging.getLogger(__name__)
_log.setLevel(logging.INFO)
[docs]
def get_triggers(session_path):
    """
    Collect the trigger label of every NIDQ recording in a session.

    Searches ``<session_path>/raw_ephys_data`` for files matching
    ``*.nidq.cbin`` and extracts the trigger label from each file stem.

    Args:
        session_path (Path): Path to the session directory.

    Returns:
        list: One trigger string per NIDQ file found, sorted with plain
            string ordering (so ``"t10"`` sorts before ``"t2"``).
    """
    ephys_dir = session_path.joinpath("raw_ephys_data")
    labels = []
    for nidq_file in ephys_dir.glob("*.nidq.cbin"):
        labels.append(get_trig_string(nidq_file.stem))
    return sorted(labels)
[docs]
def get_trig_string(in_str):
    """
    Extract the trigger label (e.g. ``"t0"``) from a string using regex.

    The label is the first occurrence of the letter ``t`` immediately
    followed by one to three digits.

    Args:
        in_str (str): Input string containing the trigger information.

    Returns:
        str: Extracted trigger string.

    Raises:
        ValueError: If ``in_str`` contains no trigger label.
    """
    match = re.search(r"t\d{1,3}", in_str)
    if match is None:
        # Fail with a message naming the offending string instead of the
        # opaque "'NoneType' object has no attribute 'group'".
        raise ValueError(
            f"No trigger label ('t' followed by 1-3 digits) found in {in_str!r}"
        )
    return match.group()
[docs]
def load_mmap(fn):
    """
    Open a Nidaq recording and return its data together with its metadata.

    Args:
        fn (Path): Path to the Nidaq.bin file.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: Data array (first element returned by
              ``spikeglx.Reader.read()``).
            - dict: Metadata dictionary (the reader's ``meta`` attribute).
    """
    reader = spikeglx.Reader(fn)
    # NOTE(review): read() is called with its default arguments and only its
    # first return value is kept. Whether that covers the whole recording and
    # is truly memory-mapped depends on the spikeglx.Reader implementation --
    # confirm against the installed version.
    data = reader.read()[0]
    return (data, reader.meta)
[docs]
def binary_onsets(x, thresh):
    """
    Binarize a signal at the level "thresh" and return the onset and offset indices.

    Only complete HIGH pulses are reported: a pulse that is already HIGH at
    the first sample, or still HIGH at the last sample, is dropped so that
    every onset has a matching offset.

    Args:
        x (array-like): One-dimensional input signal.
        thresh (float): Threshold value to determine binary state of HIGH (1) or LOW (0).

    Returns:
        tuple: A tuple containing:
            - np.ndarray: Indices of onset samples.
            - np.ndarray: Indices of offset samples.
          Each index is the position of the sample immediately *before* the
          threshold crossing (the index into ``np.diff`` of the binarized
          signal). Both arrays are empty for an empty input.

    Raises:
        ValueError: If the number of onsets does not match the number of offsets.
    """
    # Convert the signal to a boolean array based on the threshold
    xbool = np.asarray(x) > thresh
    if xbool.size == 0:
        # Nothing to inspect; avoid indexing xbool[0] on an empty array.
        return (np.array([], dtype=np.intp), np.array([], dtype=np.intp))

    # Rising edges are +1 and falling edges are -1 in the differenced signal
    edges = np.diff(xbool.astype("int"))
    ons = np.where(edges == 1)[0]
    offs = np.where(edges == -1)[0]

    # Deal with edge cases
    if xbool[0]:
        # Signal starts HIGH: the first falling edge has no onset
        offs = offs[1:]
    if xbool[-1]:
        # Signal ends HIGH: the last rising edge has no offset
        ons = ons[:-1]

    if len(ons) != len(offs):
        # Defensive check; plot the signal and threshold to aid debugging
        plt.plot(x)
        plt.axhline(thresh)
        raise ValueError("Onsets does not match offsets")
    return (ons, offs)
[docs]
def get_tvec(dat, sr):
    """
    Generate a time vector for a given data array and sampling rate.

    Sample ``i`` is assigned time ``i / sr``, so the vector starts at 0 and
    consecutive samples are exactly ``1 / sr`` apart.

    Args:
        dat (array-like): One-dimensional data array.
        sr (float): Sampling rate.

    Returns:
        np.ndarray: Time vector with one entry per sample of ``dat``.

    Raises:
        AssertionError: If ``dat`` is not one-dimensional.
    """
    assert np.ndim(dat) == 1, "Input data 'dat' must be one-dimensional."
    # arange / sr keeps the sample period at 1 / sr. The previous
    # linspace(0, N / sr, N) included the endpoint and therefore stretched
    # the spacing to N / ((N - 1) * sr).
    tvec = np.arange(len(dat)) / sr
    return tvec
[docs]
def get_tvec_from_fn(fn):
    """
    Build the time vector for the recording stored in a Nidaq file.

    Opens the file with ``spikeglx.Reader`` and delegates to
    ``get_tvec_from_SR``.

    Args:
        fn (Path): Path to the Nidaq file.

    Returns:
        np.ndarray: Time vector corresponding to the data in the file.
    """
    return get_tvec_from_SR(spikeglx.Reader(fn))
[docs]
def get_tvec_from_SR(SR):
    """
    Generate a time vector from a SpikeGLX reader object.

    Uses the reader's sample count (``SR.ns``) and sampling rate (``SR.fs``).
    Sample ``i`` is assigned time ``i / SR.fs``.

    Args:
        SR (spikeglx.Reader): SpikeGLX reader object for the recording.

    Returns:
        np.ndarray: Time vector corresponding to the data in the reader object.
    """
    sr = SR.fs
    n_samps = SR.ns
    # arange / sr keeps the sample period at 1 / sr. The previous
    # linspace(0, n / sr, n) included the endpoint and therefore stretched
    # the spacing to n / ((n - 1) * sr).
    tvec = np.arange(n_samps) / sr
    return tvec
[docs]
def load_ds_pdiff(SR, chan_id, ds_factor=10, inhale_dir=-1):
    """
    Read the pdiff (differential pressure sensor) channel, downsampled.

    The channel is extracted with ``_extract_ds_chan`` and then multiplied
    by ``inhale_dir``.

    Args:
        SR (spikeglx.Reader): SpikeGLX reader object for the recording.
        chan_id (int): Channel ID for the pdiff signal.
        ds_factor (int, optional): Downsampling factor. Defaults to 10.
        inhale_dir (int, optional): Direction of inhalation; the signal is
            multiplied by this value. Defaults to -1.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: Downsampled pdiff data.
            - float: Downsampled sampling rate.
    """
    pdiff, sr_sub = _extract_ds_chan(SR, chan_id, ds_factor)
    # The PDIFF signal is AC, so no baseline correction is applied here;
    # only the sign convention is adjusted.
    oriented = pdiff * inhale_dir
    return (oriented, sr_sub)
[docs]
def load_dia_emg(SR, chan_id):
    """
    Read the raw diaphragm EMG channel at its original rate, mean-subtracted.

    Args:
        SR (spikeglx.Reader): SpikeGLX reader object for the recording.
        chan_id (int): Channel ID for the diaphragm EMG signal.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: Diaphragm EMG data with its mean removed.
            - float: Sampling rate of the diaphragm recording.
    """
    # A factor of 1 means the channel is extracted without downsampling
    raw, sr = _extract_ds_chan(SR, chan_id, 1)
    centered = raw - np.mean(raw)
    return (centered, sr)
[docs]
def filt_int_ds_dia(x, sr, ds_factor=10, rel_height=0.95, heartbeats=None):
    """
    Filter, integrate, and downsample the diaphragm EMG signal. Detect and summarize the diaphragm bursts.

    Processing steps: EKG removal, Butterworth band-pass, rectification and
    median filtering (50 ms window), Savitzky-Golay smoothing (10 ms window),
    decimation by ``ds_factor``, burst detection, and scaling by the standard
    deviation of the downsampled trace.
    Uses median filtering to smooth the signal, which can be slow but is effective.

    Args:
        x (np.ndarray): Raw diaphragm EMG signal.
        sr (float): Sampling rate of the input signal.
        ds_factor (int, optional): Downsampling factor. Defaults to 10.
        rel_height (float, optional): Relative height for burst detection. Defaults to 0.95.
        heartbeats (np.ndarray, optional): Precomputed heartbeats. Defaults to None.

    Returns:
        tuple: A tuple containing:
            - pd.DataFrame: Burst statistics from ``physiology.burst_stats_dia``
              with an added ``amp_z`` column (``amp`` divided by the standard
              deviation of the downsampled trace).
            - np.ndarray: Downsampled diaphragm signal divided by its standard deviation.
            - float: Downsampled sampling rate.
            - Heart rate data, or None when ``heartbeats`` was supplied.
            - np.ndarray: Band-pass filtered diaphragm signal (original rate).
            - Heartbeats: the supplied ones, or those computed from the signal.

    Raises:
        AssertionError: If ``ds_factor`` is not an integer.
    """
    # Accept NumPy integer scalars as well as built-in ints (but not bools)
    assert isinstance(ds_factor, (int, np.integer)) and not isinstance(
        ds_factor, bool
    ), "ds_factor must be an integer"

    # Remove the EKG artifact
    _log.info("Removing the EKG...")
    dia_filt, pulse = physiology.remove_EKG(x, sr, thresh=2, heartbeats=heartbeats)
    # Fill any NaNs in the cleaned trace with its median before filtering
    dia_filt[np.isnan(dia_filt)] = np.nanmedian(dia_filt)

    # Filter for high frequency signal
    # NOTE(review): without `fs`, scipy.signal.butter expects band edges
    # normalized so that 1.0 is the Nyquist frequency (i.e. 2 * f / sr). The
    # edges below evaluate to 150 / sr and 2500 / sr, which corresponds to a
    # pass band of about 75-1250 Hz rather than 300-5000 Hz. Left unchanged to
    # preserve existing results -- confirm the intended band.
    sos = sig.butter(2, [300 / sr / 2, 5000 / sr / 2], btype="bandpass", output="sos")
    dia_filt = sig.sosfilt(sos, dia_filt)

    # Use a median filter on the rectified EMG to get its smoothed envelope
    _log.info("Smoothing the rectified trace...")
    window_length = int(0.05 * np.round(sr)) + 1
    if window_length % 2 == 0:
        window_length += 1  # keep the window odd
    dd = median_filter(np.abs(dia_filt), window_length)

    # Smooth it out a little more
    window_length = int(0.01 * np.round(sr)) + 1
    if window_length % 2 == 0:
        window_length += 1  # savgol_filter needs an odd window here
    dia_smooth = sig.savgol_filter(dd, window_length=window_length, polyorder=1)

    # Downsample because we don't need this at the original sampling rate
    dia_sub = dia_smooth[::ds_factor]
    sr_sub = sr / ds_factor

    # Get the burst statistics. Warnings are silenced only inside this
    # context; the caller's warning filters are restored afterwards, even if
    # burst detection raises.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        dia_df = physiology.burst_stats_dia(dia_sub, sr_sub, rel_height=rel_height)

    # Compute heart rate from diaphragm signal if heartbeats are not provided
    HR = None
    if heartbeats is None:
        HR, heartbeats = physiology.get_hr_from_dia(pulse / ds_factor, dia_df, sr_sub)

    # Scale the integrated diaphragm (and burst amplitudes) by the standard
    # deviation of the downsampled trace; the mean is not subtracted.
    dia_std = np.std(dia_sub)
    dia_df["amp_z"] = dia_df["amp"] / dia_std
    dia_sub = dia_sub / dia_std
    _log.info("Done processing diaphragm")
    return (dia_df, dia_sub, sr_sub, HR, dia_filt, heartbeats)
[docs]
def filt_int_ds_arbitrary(x, sr, ds_factor=10):
    """
    Filter, integrate, and downsample an arbitrary signal.

    Applies a second order Butterworth band-pass filter, rectifies and median
    filters the result (50 ms window), smooths it with a Savitzky-Golay
    filter (10 ms window), decimates by ``ds_factor``, and divides by the
    standard deviation of the decimated trace.

    NOTE(review): the band edges are coded as ``300 / sr / 2`` and
    ``5000 / sr / 2``. Because scipy.signal.butter normalizes to the Nyquist
    frequency (``2 * f / sr``), this yields a pass band of about 75-1250 Hz,
    not 300-5000 Hz. The code is left unchanged to preserve existing
    results -- confirm the intended band.

    Args:
        x (np.ndarray): Input signal.
        sr (float): Sampling rate of the input signal.
        ds_factor (int, optional): Downsampling factor. Defaults to 10.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: Processed and downsampled signal, divided by its standard deviation.
            - float: Downsampled sampling rate.
            - np.ndarray: Band-pass filtered signal (original rate).

    Raises:
        AssertionError: If ``ds_factor`` is not an integer.
    """
    # Accept NumPy integer scalars as well as built-in ints (but not bools)
    assert isinstance(ds_factor, (int, np.integer)) and not isinstance(
        ds_factor, bool
    ), "ds_factor must be an integer"

    # Filter for high frequency signal (see the NOTE in the docstring about
    # the effective pass band)
    sos = sig.butter(2, [300 / sr / 2, 5000 / sr / 2], btype="bandpass", output="sos")
    x_filt = sig.sosfilt(sos, x)

    # Use a median filter on the rectified signal to get its smoothed envelope
    _log.info("Smoothing the rectified trace...")
    window_length = int(0.05 * np.round(sr)) + 1
    if window_length % 2 == 0:
        window_length += 1  # keep the window odd
    dd = median_filter(np.abs(x_filt), window_length)

    # Smooth it out a little more
    window_length = int(0.01 * np.round(sr)) + 1
    if window_length % 2 == 0:
        window_length += 1  # savgol_filter needs an odd window here
    x_smooth = sig.savgol_filter(dd, window_length=window_length, polyorder=1)

    # Downsample because we don't need this at the original sampling rate
    x_sub = x_smooth[::ds_factor]
    sr_sub = sr / ds_factor

    # Scale the integrated signal by its standard deviation (mean not removed)
    x_sub = x_sub / np.std(x_sub)
    _log.info("Done processing signal")
    return (x_sub, sr_sub, x_filt)