# Source code for brainaccess.connect.processor

import ctypes
from brainaccess.connect import _dll
import numpy as np


# ctypes

def _declare_signatures() -> None:
    """Declare argument and return types of the native processing functions.

    Every function wrapped in this module starts with the same three
    arguments -- a pointer to the flattened input samples followed by two
    ``size_t`` values (the wrappers below pass the channel count and the
    number of time points) -- and returns nothing.  Only the trailing
    arguments differ, so the signatures are kept in a single table.
    """
    f64 = ctypes.c_double
    f64_p = ctypes.POINTER(ctypes.c_double)
    size = ctypes.c_size_t

    # Leading arguments shared by all functions.
    head = (f64_p, size, size)

    signatures = {
        # Per-channel statistics: trailing output pointer.
        "ba_bci_connect_mean": head + (f64_p,),
        "ba_bci_connect_st_deviation": head + (f64_p,),
        # Whole-array transforms: trailing output pointer.
        "ba_bci_connect_demean": head + (f64_p,),
        "ba_bci_connect_standartize": head + (f64_p,),
        # EWMA variants: scalar parameter(s), then output pointer.
        "ba_bci_connect_ewma": head + (f64, f64_p),
        "ba_bci_connect_ewma_standartize": head + (f64, f64, f64_p),
        # Filters: scalar parameters only, no separate output pointer.
        "ba_bci_connect_filter_notch": head + (f64, f64, f64),
        "ba_bci_connect_filter_bandpass": head + (f64, f64, f64),
        "ba_bci_connect_filter_highpass": head + (f64, f64),
        "ba_bci_connect_filter_lowpass": head + (f64, f64),
        # FFT: one scalar parameter, then two output pointers.
        "ba_bci_connect_fft": head + (f64, f64_p, f64_p),
    }

    for name, argtypes in signatures.items():
        native = getattr(_dll, name)
        native.argtypes = argtypes
        native.restype = None


_declare_signatures()


def mean(x: np.ndarray) -> np.ndarray:
    """Calculate mean for each channel in the data

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)

    Returns
    --------
    np.ndarray
        means for each channel in the same order as x

    Raises
    -------
    ValueError
        if x is not a 2-D array
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape
    double_p = ctypes.POINTER(ctypes.c_double)

    # Single C-ordered float64 copy; the caller's array is never passed on.
    data = np.array(x, dtype=np.float64, order="C")
    # Output buffer with one slot per channel, filled by the native call.
    result = np.zeros(chans)

    _dll.ba_bci_connect_mean(
        data.ctypes.data_as(double_p),
        chans,
        time_points,
        result.ctypes.data_as(double_p),
    )
    return result
def st_deviation(x: np.ndarray) -> np.ndarray:
    """Calculate standard deviation for each channel in the data

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)

    Returns
    --------
    np.ndarray
        standard deviation for each channel in the same order as x

    Raises
    -------
    ValueError
        if x is not a 2-D array
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape
    double_p = ctypes.POINTER(ctypes.c_double)

    # Single C-ordered float64 copy; the caller's array is never passed on.
    data = np.array(x, dtype=np.float64, order="C")
    # Output buffer with one slot per channel, filled by the native call.
    result = np.zeros(chans)

    _dll.ba_bci_connect_st_deviation(
        data.ctypes.data_as(double_p),
        chans,
        time_points,
        result.ctypes.data_as(double_p),
    )
    return result
def demean(x: np.ndarray) -> np.ndarray:
    """Subtract mean from each channel

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)

    Returns
    -----------
    np.ndarray
        data array, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape
    double_p = ctypes.POINTER(ctypes.c_double)

    # Single C-ordered float64 copy; the caller's array is never passed on.
    data = np.array(x, dtype=np.float64, order="C")
    # The native call writes chans * time_points values here; a C-ordered
    # 2-D array has exactly that flat layout, so no reshape is needed.
    out = np.zeros((chans, time_points))

    _dll.ba_bci_connect_demean(
        data.ctypes.data_as(double_p),
        chans,
        time_points,
        out.ctypes.data_as(double_p),
    )
    return out
def standartize(x: np.ndarray) -> np.ndarray:
    """Data standardization

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)

    Returns
    -----------
    np.ndarray
        data array, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape
    double_p = ctypes.POINTER(ctypes.c_double)

    # Single C-ordered float64 copy; the caller's array is never passed on.
    data = np.array(x, dtype=np.float64, order="C")
    # The native call writes chans * time_points values here; a C-ordered
    # 2-D array has exactly that flat layout, so no reshape is needed.
    out = np.zeros((chans, time_points))

    _dll.ba_bci_connect_standartize(
        data.ctypes.data_as(double_p),
        chans,
        time_points,
        out.ctypes.data_as(double_p),
    )
    return out
def ewma(x: np.ndarray, alpha: float = 0.001) -> np.ndarray:
    """Exponential weighted moving average helper function

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)
    alpha: float
        new factor, forwarded as-is to the native routine

    Returns
    -----------
    np.ndarray
        data array, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape
    double_p = ctypes.POINTER(ctypes.c_double)

    # Single C-ordered float64 copy; the caller's array is never passed on.
    data = np.array(x, dtype=np.float64, order="C")
    # The native call writes chans * time_points values here; a C-ordered
    # 2-D array has exactly that flat layout, so no reshape is needed.
    out = np.zeros((chans, time_points))

    _dll.ba_bci_connect_ewma(
        data.ctypes.data_as(double_p),
        chans,
        time_points,
        float(alpha),
        out.ctypes.data_as(double_p),
    )
    return out
def ewma_standardize(
    x: np.ndarray, alpha: float = 0.001, epsilon: float = 1e-4
) -> np.ndarray:
    """Exponential weighted moving average standardization

    First-order infinite impulse response filter that applies weighting
    factors which decrease exponentially

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)
    alpha: float
        Represents the degree of weighting decrease, a constant smoothing
        factor between 0 and 1. A higher alpha discounts older observations
        faster.
    epsilon: float
        Stabilizer for division by zero variance

    Returns
    -----------
    np.ndarray
        data array, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape
    double_p = ctypes.POINTER(ctypes.c_double)

    # Single C-ordered float64 copy; the caller's array is never passed on.
    data = np.array(x, dtype=np.float64, order="C")
    # The native call writes chans * time_points values here; a C-ordered
    # 2-D array has exactly that flat layout, so no reshape is needed.
    out = np.zeros((chans, time_points))

    # The native symbol keeps its original "standartize" spelling.
    _dll.ba_bci_connect_ewma_standartize(
        data.ctypes.data_as(double_p),
        chans,
        time_points,
        float(alpha),
        float(epsilon),
        out.ctypes.data_as(double_p),
    )
    return out
def filter_notch(
    x: np.ndarray, sampling_freq: float, center_freq: float, width_freq: float
) -> np.ndarray:
    """Notch filter at desired frequency

    Butterworth 4th order zero phase bandpass filter

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)
    sampling_freq: float
        data sampling rate
    center_freq: float
        notch filter center frequency
    width_freq: float
        notch filter width

    Returns
    -----------
    np.ndarray
        data array, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array

    Warnings
    ---------
    Data must be detrended or passed through ba_bci_connect_ewma_standartize
    before applying notch filter
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape

    # The native filter has no output argument and overwrites its input, so
    # it gets a private C-ordered float64 copy; the caller's x is untouched.
    filtered = np.array(x, dtype=np.float64, order="C")

    _dll.ba_bci_connect_filter_notch(
        filtered.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        chans,
        time_points,
        float(sampling_freq),
        float(center_freq),
        float(width_freq),
    )
    return filtered
def filter_bandpass(
    x: np.ndarray, sampling_freq: float, freq_low: float, freq_high: float
) -> np.ndarray:
    """Bandpass filter

    Butterworth 5th order zero phase bandpass filter

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)
    sampling_freq: float
        data sampling rate
    freq_low: float
        frequency to filter from
    freq_high: float
        frequency to filter to

    Returns
    -----------
    np.ndarray
        filtered data, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array

    Warnings
    ---------
    Data must be detrended or passed through ba_bci_connect_ewma_standartize
    before applying the filter
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape

    # The native filter has no output argument and overwrites its input, so
    # it gets a private C-ordered float64 copy; the caller's x is untouched.
    filtered = np.array(x, dtype=np.float64, order="C")

    _dll.ba_bci_connect_filter_bandpass(
        filtered.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        chans,
        time_points,
        float(sampling_freq),
        float(freq_low),
        float(freq_high),
    )
    return filtered
def filter_highpass(x: np.ndarray, sampling_freq: float, freq: float) -> np.ndarray:
    """High-pass filter

    Butterworth 5th order zero phase high-pass filter

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)
    sampling_freq: float
        data sampling rate
    freq: float
        edge frequency

    Returns
    -----------
    np.ndarray
        filtered data, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array

    Warnings
    ---------
    Data must be detrended or passed through ba_bci_connect_ewma_standartize
    before applying the filter
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape

    # The native filter has no output argument and overwrites its input, so
    # it gets a private C-ordered float64 copy; the caller's x is untouched.
    filtered = np.array(x, dtype=np.float64, order="C")

    _dll.ba_bci_connect_filter_highpass(
        filtered.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        chans,
        time_points,
        float(sampling_freq),
        float(freq),
    )
    return filtered
def filter_lowpass(x: np.ndarray, sampling_freq: float, freq: float) -> np.ndarray:
    """Low-pass filter

    Butterworth 5th order zero phase low-pass filter

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)
    sampling_freq: float
        data sampling rate
    freq: float
        edge frequency

    Returns
    -----------
    np.ndarray
        filtered data, shape (channels, time)

    Raises
    -------
    ValueError
        if x is not a 2-D array

    Warnings
    ---------
    Data must be detrended or passed through ba_bci_connect_ewma_standartize
    before applying the filter
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape

    # The native filter has no output argument and overwrites its input, so
    # it gets a private C-ordered float64 copy; the caller's x is untouched.
    filtered = np.array(x, dtype=np.float64, order="C")

    _dll.ba_bci_connect_filter_lowpass(
        filtered.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        chans,
        time_points,
        float(sampling_freq),
        float(freq),
    )
    return filtered
def fft(x: np.ndarray, sampling_freq: float) -> dict:
    """Compute the discrete Fourier Transform (DFT) with the efficient
    Fast Fourier Transform (FFT) algorithm

    Parameters
    -----------
    x: np.ndarray
        data array, shape (channels, time)
    sampling_freq: float
        data sampling rate

    Returns
    -----------
    dict
        dictionary (key: value)

        - freq: frequencies, shape (time // 2 + 1,)
        - mag: amplitudes, shape (channels, time // 2 + 1)
        - phase: phases, shape (channels, time // 2 + 1)

    Raises
    -------
    ValueError
        if x is not a 2-D array
    """
    # Anything but 2-D would hand the native code sizes that do not match
    # the buffer, so reject it up front.
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D with shape (channels, time), got shape {}".format(x.shape)
        )
    chans, time_points = x.shape
    double_p = ctypes.POINTER(ctypes.c_double)

    # Single C-ordered float64 copy; the caller's array is never passed on.
    data = np.array(x, dtype=np.float64, order="C")

    # Number of frequency bins per channel: the sample count rounded down to
    # an even number, halved, plus one.
    n_bins = (time_points - (time_points % 2)) // 2 + 1

    # Output buffers filled by the native call, one row of bins per channel.
    mags = np.zeros((chans, n_bins))
    phases = np.zeros((chans, n_bins))

    _dll.ba_bci_connect_fft(
        data.ctypes.data_as(double_p),
        chans,
        time_points,
        sampling_freq,
        mags.ctypes.data_as(double_p),
        phases.ctypes.data_as(double_p),
    )

    # NOTE(review): the axis always ends at sampling_freq / 2, which matches
    # the bin positions only if the native routine works on an even number
    # of samples -- confirm its behaviour for odd time_points.
    freqs = np.linspace(0, sampling_freq / 2, n_bins)

    # Magnitudes are doubled before being returned, as in the original code.
    return {"freq": freqs, "mag": mags * 2, "phase": phases}