Source code for nept.lfp_filtering

import numpy as np
import scipy.signal
import scipy.stats
import matplotlib.mlab
import nept


def butter_bandpass(signal, thresh, fs, order=4):
    """Band-pass a signal with a Butterworth filter run forward and backward.

    Parameters
    ----------
    signal : nept.LFP
        Passed through ``np.squeeze`` before filtering.
    thresh : tuple
        With format (lowcut, highcut), expressed in the same units as ``fs``.
        Typically (140.0, 250.0) for sharp-wave ripple detection.
    fs : int
        Sampling rate, e.g. 2000. Should get this from experiment-specifics.
    order : int
        Order handed to ``scipy.signal.butter``. Default set to 4.

    Returns
    -------
    filtered_butter : np.array
        Output of ``scipy.signal.filtfilt`` on the squeezed signal.

    """
    squeezed = np.squeeze(signal)

    # scipy.signal.butter expects the band edges as fractions of the
    # Nyquist frequency (half of the sampling rate).
    nyquist = 0.5 * fs
    band = [thresh[0] / nyquist, thresh[1] / nyquist]

    numerator, denominator = scipy.signal.butter(order, band, btype='band')
    return scipy.signal.filtfilt(numerator, denominator, squeezed)
def detect_swr_hilbert(lfp, fs, thresh, z_thresh=3, power_thresh=3, merge_thresh=0.02, min_length=0.01):
    """Finds sharp-wave ripple (SWR) times.

    Parameters
    ----------
    lfp : nept.LocalFieldPotential
        Must expose ``data``, ``time`` and ``n_samples``.
    fs : int
        Experiment-specific, something in the range of 2000 typical.
    thresh : tuple
        With format (lowcut, highcut). Typically (140.0, 250.0) for
        sharp-wave ripple detection.
    z_thresh : int or float
        Samples whose z-scored power exceeds this value are flagged as
        belonging to a candidate event. The default is set to 3.
    power_thresh : int or float
        Candidate events whose peak z-scored power is below this value are
        discarded. Only applied when it is greater than ``z_thresh``.
        The default is set to 3.
    merge_thresh : int or float
        Candidate events separated by less than this amount (in the units
        of ``lfp.time``) are merged. The default is set to 0.02.
    min_length : float
        Any sequence less than this amount (in the units of ``lfp.time``)
        is not considered a sharp-wave ripple. The default is set to 0.01.

    Returns
    -------
    swrs : nept.Epoch
        Built from the array ``[start_times, stop_times]`` of the events
        that survive merging and filtering.

    """
    # Filtering signal with butterworth filter
    filtered_butter = butter_bandpass(lfp.data, thresh, fs)

    # Get LFP power (using Hilbert) and z-score the power.
    # Zero padding to nearest regular number to speed up fast fourier
    # transforms (FFT) computed in the hilbert function.
    # Regular numbers are composites of the prime factors 2, 3, and 5.
    hilbert_n = next_regular(lfp.n_samples)
    power_lfp = np.abs(scipy.signal.hilbert(filtered_butter, N=hilbert_n))
    # Removing the zero padding now that the power is computed
    power_lfp = power_lfp[:lfp.n_samples]
    zpower_lfp = scipy.stats.zscore(power_lfp)

    # Finding locations where the power crosses the z-score threshold.
    detect = zpower_lfp > z_thresh
    detect = np.hstack([0, detect, 0])  # pad to detect first or last element change
    signal_change = np.diff(detect.astype(int))
    start_swr_idx = np.where(signal_change == 1)[0]
    # The falling edge sits one sample past the event, so step back to get
    # the (inclusive) index of the last supra-threshold sample.
    stop_swr_idx = np.where(signal_change == -1)[0] - 1

    # Getting times associated with these power changes
    start_time = lfp.time[start_swr_idx]
    stop_time = lfp.time[stop_swr_idx]

    # Merging ranges that are closer - in time - than the merge_threshold.
    # A small gap between event k and k+1 removes stop k and start k+1.
    no_double = start_time[1:] - stop_time[:-1]
    merge_idx = np.where(no_double < merge_thresh)[0]
    start_merged = np.delete(start_time, merge_idx + 1)
    stop_merged = np.delete(stop_time, merge_idx)
    start_merged_idx = np.delete(start_swr_idx, merge_idx + 1)
    stop_merged_idx = np.delete(stop_swr_idx, merge_idx)

    # Removing ranges that are shorter - in time - than the min_length value.
    swr_len = stop_merged - start_merged
    short_idx = np.where(swr_len < min_length)[0]
    start_merged = np.delete(start_merged, short_idx)
    stop_merged = np.delete(stop_merged, short_idx)
    start_merged_idx = np.delete(start_merged_idx, short_idx)
    stop_merged_idx = np.delete(stop_merged_idx, short_idx)

    # Removing ranges that have powers less than the power_threshold
    # if sufficiently different.
    if power_thresh > z_thresh:
        # stop_idx is inclusive, hence the "+ 1": without it the last
        # sample of every event is ignored and a one-sample event would
        # hand np.max an empty slice (ValueError).
        max_z = np.array([
            np.max(zpower_lfp[start_idx:stop_idx + 1])
            for start_idx, stop_idx in zip(start_merged_idx, stop_merged_idx)
        ])
        z_idx = np.where(max_z < power_thresh)[0]
        start_merged = np.delete(start_merged, z_idx)
        stop_merged = np.delete(stop_merged, z_idx)
        start_merged_idx = np.delete(start_merged_idx, z_idx)
        stop_merged_idx = np.delete(stop_merged_idx, z_idx)

    swrs = nept.Epoch(np.array([start_merged, stop_merged]))

    return swrs
def next_regular(target):
    """Find the next regular number greater than or equal to target.

    Regular numbers are composites of the prime factors 2, 3, and 5.
    Also known as 5-smooth numbers or Hamming numbers, these are the optimal
    size for inputs to fast-fourier transforms (FFTPACK).

    Parameters
    ----------
    target : positive int
        Must support integer bit operations (``&``, ``bit_length``).

    Returns
    -------
    match : int
        The smallest regular number that is ``>= target``.

    Notes
    -----
    This function was taken from the scipy.signal.signaltools module.
    See http://scipy.org/scipylib/

    """
    # Every integer up to 6 is already a product of 2s, 3s and 5s.
    if target <= 6:
        return target

    # Quickly check if it's already a power of 2
    if not (target & (target - 1)):
        return target

    match = float('inf')  # Anything found will be smaller
    p5 = 1
    while p5 < target:
        p35 = p5
        while p35 < target:
            # Ceiling integer division, avoiding conversion to float
            # (quotient = ceil(target / p35))
            quotient = -(-target // p35)

            # Quickly find next power of 2 >= quotient
            p2 = 2 ** ((quotient - 1).bit_length())

            N = p2 * p35
            if N == target:
                return N
            elif N < match:
                match = N
            p35 *= 3
            if p35 == target:
                return p35
        # p35 itself (no factor of 2) may be the best candidate so far.
        if p35 < match:
            match = p35
        p5 *= 5
        if p5 == target:
            return p5
    # Likewise for the pure power of 5 that ended the outer loop.
    if p5 < match:
        match = p5
    return match
def power_in_db(power):
    """Converts power to decibels (``10 * log10(power)``) for plotting

    Parameters
    ----------
    power : np.array

    Returns
    -------
    np.array
        Element-wise ``10 * np.log10(power)``.

    """
    log_power = np.log10(power)
    return 10 * log_power
def mean_psd(perievent_lfps, window, fs):
    """Averages the Power Spectral Density (PSD) over perievent slices

    Each column of ``perievent_lfps.data`` is passed to
    ``matplotlib.mlab.psd`` and the resulting spectra are averaged.

    Parameters
    ----------
    perievent_lfps : nept.AnalogSignal
        Must expose ``data`` and ``dimensions``.
    window : int
        The FFT length is ``2 * window`` and the segment overlap is
        ``window / 2`` (truncated to int).
    fs : int
        Sampling rate forwarded to ``matplotlib.mlab.psd`` as ``Fs``.

    Returns
    -------
    freq : np.array
        Frequencies returned by the last ``matplotlib.mlab.psd`` call.
    power : np.array
        Mean of the per-slice spectra, with ``window + 1`` entries.

    """
    # One column of spectral power per slice.
    power = np.zeros((window + 1, perievent_lfps.dimensions))

    for column, trace in enumerate(perievent_lfps.data.T):
        spectrum, freq = matplotlib.mlab.psd(
            trace,
            Fs=fs,
            NFFT=int(window * 2),
            noverlap=int(window / 2),
        )
        power[:, column] = spectrum

    mean_power = np.mean(power, axis=1)
    return freq, mean_power
def mean_csd(perievent_lfp1, perievent_lfp2, window, fs):
    """Averages the Cross-Spectral Density (CSD) between perievent slices

    The transposed ``data`` of both inputs is handed to
    ``scipy.signal.csd`` and the result is averaged along its first axis.

    Parameters
    ----------
    perievent_lfp1 : nept.AnalogSignal
    perievent_lfp2 : nept.AnalogSignal
    window : int
        Segment length (``nperseg``); the FFT length is ``2 * window``.
    fs : int
        Sampling rate forwarded to ``scipy.signal.csd``.

    Returns
    -------
    freq : np.array
    power : np.array
        Cross-spectral density averaged over axis 0.

    """
    first = perievent_lfp1.data.T
    second = perievent_lfp2.data.T

    freq, power = scipy.signal.csd(
        first,
        second,
        fs=fs,
        nperseg=window,
        nfft=int(window * 2),
    )

    mean_power = np.mean(power, axis=0)
    return freq, mean_power
def mean_coherence(perievent_lfp1, perievent_lfp2, window, fs):
    """Averages the coherence between perievent slices

    The transposed ``data`` of both inputs is handed to
    ``scipy.signal.coherence`` and the result is averaged along its
    first axis.

    Parameters
    ----------
    perievent_lfp1 : nept.AnalogSignal
    perievent_lfp2 : nept.AnalogSignal
    window : int
        Segment length (``nperseg``); the FFT length is ``2 * window``.
    fs : int
        Sampling rate forwarded to ``scipy.signal.coherence``.

    Returns
    -------
    freq : np.array
    coherence : np.array
        Coherence averaged over axis 0.

    """
    first = perievent_lfp1.data.T
    second = perievent_lfp2.data.T

    freq, coherence = scipy.signal.coherence(
        first,
        second,
        fs=fs,
        nperseg=window,
        nfft=int(window * 2),
    )

    mean_across_slices = np.mean(coherence, axis=0)
    return freq, mean_across_slices
def mean_coherencegram(perievent_lfp1, perievent_lfp2, dt, window, fs, extend=0.3):
    """Computes the mean coherence over time between perievent slices

    (e.g. "coherencegram" because it's a combination of a coherence and
    a spectrogram). For each pair of consecutive time bin edges, both
    signals are sliced (widened by ``extend`` on each side) and passed to
    ``mean_coherence``.

    Parameters
    ----------
    perievent_lfp1 : nept.AnalogSignal
        Its ``time`` attribute defines the time bins.
    perievent_lfp2 : nept.AnalogSignal
    dt : float
        Spacing between consecutive time bin edges.
    window : int
        Forwarded to ``mean_coherence``; the output has ``window + 1``
        frequency rows.
    fs : int
        Forwarded to ``mean_coherence``.
    extend : float
        Amount added before the start and after the stop of every slice.
        Defaults to 0.3

    Returns
    -------
    timebins : np.array
    freq : np.array
        Frequencies from the last ``mean_coherence`` call.
    coherencegram : np.array
        Shape ``(window + 1, len(timebins))``.

    """
    first_time = perievent_lfp1.time[0]
    last_time = perievent_lfp1.time[-1]
    timebins = np.arange(first_time, last_time + dt, dt)

    coherencegram = np.zeros((window + 1, len(timebins)))

    # NOTE(review): only the first len(timebins) - 2 columns are filled;
    # the last two stay at zero. Confirm that this is intended.
    n_slices = max(len(timebins) - 2, 0)
    for column in range(n_slices):
        slice_start = timebins[column] - extend
        slice_stop = timebins[column + 1] + extend
        lfp1 = perievent_lfp1.time_slice(slice_start, slice_stop)
        lfp2 = perievent_lfp2.time_slice(slice_start, slice_stop)
        freq, coherencegram[:, column] = mean_coherence(lfp1, lfp2, window, fs)

    return timebins, freq, coherencegram