import ctypes
from brainaccess.connect import _dll
import numpy as np
# ctypes
def _declare_signatures() -> None:
    """Attach C argument types and a ``void`` return type to each native routine.

    With ``argtypes`` set, ctypes converts the Python arguments of every call
    to the listed C types; ``restype = None`` marks each routine as returning
    nothing (results come back through output buffers).
    """
    dbl_ptr = ctypes.POINTER(ctypes.c_double)
    size = ctypes.c_size_t
    dbl = ctypes.c_double
    signatures = (
        # (data, channels, time points, output buffer)
        ("ba_bci_connect_mean", [dbl_ptr, size, size, dbl_ptr]),
        ("ba_bci_connect_st_deviation", [dbl_ptr, size, size, dbl_ptr]),
        ("ba_bci_connect_demean", [dbl_ptr, size, size, dbl_ptr]),
        ("ba_bci_connect_standartize", [dbl_ptr, size, size, dbl_ptr]),
        # (data, channels, time points, alpha[, epsilon], output buffer)
        ("ba_bci_connect_ewma", (dbl_ptr, size, size, dbl, dbl_ptr)),
        (
            "ba_bci_connect_ewma_standartize",
            (dbl_ptr, size, size, dbl, dbl, dbl_ptr),
        ),
        # (data filtered in place, channels, time points, filter parameters...)
        ("ba_bci_connect_filter_notch", (dbl_ptr, size, size, dbl, dbl, dbl)),
        ("ba_bci_connect_filter_bandpass", (dbl_ptr, size, size, dbl, dbl, dbl)),
        ("ba_bci_connect_filter_highpass", (dbl_ptr, size, size, dbl, dbl)),
        ("ba_bci_connect_filter_lowpass", (dbl_ptr, size, size, dbl, dbl)),
        # (data, channels, time points, sampling rate, magnitudes, phases)
        ("ba_bci_connect_fft", [dbl_ptr, size, size, dbl, dbl_ptr, dbl_ptr]),
    )
    for symbol, argtypes in signatures:
        routine = getattr(_dll, symbol)
        routine.argtypes = argtypes
        routine.restype = None


_declare_signatures()
def mean(x: np.ndarray) -> np.ndarray:
    """Calculate the mean of each channel in the data.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). It is copied to a contiguous
        float64 buffer before being handed to the native routine, so the
        caller's array is never modified.

    Returns
    -------
    np.ndarray
        Float64 array of shape (channels,) with the mean of each channel,
        in the same order as the rows of ``x``.
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # np.array(..., order="C") always copies, giving the contiguous float64
    # layout the C routine reads in a single allocation.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    result = np.zeros(chans, dtype=np.float64)
    # as_ctypes shares memory with ``result``, so the C routine fills the
    # numpy array directly and it can be returned without a list round-trip.
    _dll.ba_bci_connect_mean(
        np.ctypeslib.as_ctypes(_x),
        chans,
        time_points,
        np.ctypeslib.as_ctypes(result),
    )
    return result
def st_deviation(x: np.ndarray) -> np.ndarray:
    """Calculate the standard deviation of each channel in the data.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). It is copied to a contiguous
        float64 buffer before being handed to the native routine, so the
        caller's array is never modified.

    Returns
    -------
    np.ndarray
        Float64 array of shape (channels,) with the standard deviation of
        each channel, in the same order as the rows of ``x``.
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Single contiguous float64 copy of the input for the C routine.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    result = np.zeros(chans, dtype=np.float64)
    # The ctypes view aliases ``result``; the C routine writes straight into it.
    _dll.ba_bci_connect_st_deviation(
        np.ctypeslib.as_ctypes(_x),
        chans,
        time_points,
        np.ctypeslib.as_ctypes(result),
    )
    return result
def demean(x: np.ndarray) -> np.ndarray:
    """Subtract the mean from each channel.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.

    Returns
    -------
    np.ndarray
        Float64 data array, shape (channels, time).
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Single contiguous float64 copy of the input for the C routine.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    # Flat output buffer; the C routine fills it through the ctypes view.
    result = np.zeros(chans * time_points, dtype=np.float64)
    _dll.ba_bci_connect_demean(
        np.ctypeslib.as_ctypes(_x),
        chans,
        time_points,
        np.ctypeslib.as_ctypes(result),
    )
    return result.reshape((chans, time_points))
def standartize(x: np.ndarray) -> np.ndarray:
    """Standardize the data using the native routine.

    NOTE(review): presumably per-channel z-scoring, i.e.
    ``(x - mean) / std`` for each row — confirm against the C implementation.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.

    Returns
    -------
    np.ndarray
        Float64 data array, shape (channels, time).
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Single contiguous float64 copy of the input for the C routine.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    # Flat output buffer; the C routine fills it through the ctypes view.
    result = np.zeros(chans * time_points, dtype=np.float64)
    _dll.ba_bci_connect_standartize(
        np.ctypeslib.as_ctypes(_x),
        chans,
        time_points,
        np.ctypeslib.as_ctypes(result),
    )
    return result.reshape((chans, time_points))
def ewma(x: np.ndarray, alpha: float = 0.001) -> np.ndarray:
    """Exponentially weighted moving average of each channel.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.
    alpha : float
        Smoothing factor: the weight given to each new sample.

    Returns
    -------
    np.ndarray
        Float64 data array, shape (channels, time), holding the moving
        average at every time point.
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Single contiguous float64 copy of the input for the C routine.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    # Flat output buffer; the C routine fills it through the ctypes view.
    result = np.zeros(chans * time_points, dtype=np.float64)
    _dll.ba_bci_connect_ewma(
        np.ctypeslib.as_ctypes(_x),
        chans,
        time_points,
        np.float64(alpha),
        np.ctypeslib.as_ctypes(result),
    )
    return result.reshape((chans, time_points))
def ewma_standardize(
    x: np.ndarray, alpha: float = 0.001, epsilon: float = 1e-4
) -> np.ndarray:
    """Exponentially weighted moving average standardization.

    First-order infinite impulse response filter that applies weighting
    factors which decrease exponentially.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.
    alpha : float
        The degree of weighting decrease, a constant smoothing factor between
        0 and 1. A higher alpha discounts older observations faster.
    epsilon : float
        Stabilizer against division by zero variance.

    Returns
    -------
    np.ndarray
        Float64 data array, shape (channels, time).
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Single contiguous float64 copy of the input for the C routine.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    # Flat output buffer; the C routine fills it through the ctypes view.
    result = np.zeros(chans * time_points, dtype=np.float64)
    _dll.ba_bci_connect_ewma_standartize(
        np.ctypeslib.as_ctypes(_x),
        chans,
        time_points,
        np.float64(alpha),
        np.float64(epsilon),
        np.ctypeslib.as_ctypes(result),
    )
    return result.reshape((chans, time_points))
def filter_notch(
    x: np.ndarray, sampling_freq: float, center_freq: float, width_freq: float
) -> np.ndarray:
    """Notch filter at the desired frequency.

    Butterworth 4th order zero phase band-stop filter.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.
    sampling_freq : float
        Data sampling rate.
    center_freq : float
        Notch filter center frequency.
    width_freq : float
        Notch filter width.

    Returns
    -------
    np.ndarray
        Float64 filtered data array, shape (channels, time).

    Warnings
    --------
    Data must be detrended or passed through ``ewma_standardize`` (native
    ``ba_bci_connect_ewma_standartize``) before applying the notch filter.
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Private contiguous float64 copy: the C routine filters this buffer in
    # place, so the caller's array stays untouched.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    _dll.ba_bci_connect_filter_notch(
        np.ctypeslib.as_ctypes(_x),
        ctypes.c_size_t(chans),
        ctypes.c_size_t(time_points),
        ctypes.c_double(sampling_freq),
        ctypes.c_double(center_freq),
        ctypes.c_double(width_freq),
    )
    # The filtered samples now live in ``_x``; return them directly.
    return _x[: chans * time_points].reshape((chans, time_points))
def filter_bandpass(
    x: np.ndarray, sampling_freq: float, freq_low: float, freq_high: float
) -> np.ndarray:
    """Bandpass filter.

    Butterworth 5th order zero phase bandpass filter.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.
    sampling_freq : float
        Data sampling rate.
    freq_low : float
        Lower edge of the pass band.
    freq_high : float
        Upper edge of the pass band.

    Returns
    -------
    np.ndarray
        Float64 filtered data array, shape (channels, time).

    Warnings
    --------
    Data must be detrended or passed through ``ewma_standardize`` (native
    ``ba_bci_connect_ewma_standartize``) before filtering.
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Private contiguous float64 copy: the C routine filters this buffer in
    # place, so the caller's array stays untouched.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    _dll.ba_bci_connect_filter_bandpass(
        np.ctypeslib.as_ctypes(_x),
        ctypes.c_size_t(chans),
        ctypes.c_size_t(time_points),
        ctypes.c_double(sampling_freq),
        ctypes.c_double(freq_low),
        ctypes.c_double(freq_high),
    )
    # The filtered samples now live in ``_x``; return them directly.
    return _x[: chans * time_points].reshape((chans, time_points))
def filter_highpass(x: np.ndarray, sampling_freq: float, freq: float) -> np.ndarray:
    """High-pass filter.

    Butterworth 5th order zero phase high-pass filter.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.
    sampling_freq : float
        Data sampling rate.
    freq : float
        Edge frequency.

    Returns
    -------
    np.ndarray
        Float64 filtered data array, shape (channels, time).

    Warnings
    --------
    Data must be detrended or passed through ``ewma_standardize`` (native
    ``ba_bci_connect_ewma_standartize``) before filtering.
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Private contiguous float64 copy: the C routine filters this buffer in
    # place, so the caller's array stays untouched.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    _dll.ba_bci_connect_filter_highpass(
        np.ctypeslib.as_ctypes(_x),
        ctypes.c_size_t(chans),
        ctypes.c_size_t(time_points),
        ctypes.c_double(sampling_freq),
        ctypes.c_double(freq),
    )
    # The filtered samples now live in ``_x``; return them directly.
    return _x[: chans * time_points].reshape((chans, time_points))
def filter_lowpass(x: np.ndarray, sampling_freq: float, freq: float) -> np.ndarray:
    """Low-pass filter.

    Butterworth 5th order zero phase low-pass filter.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.
    sampling_freq : float
        Data sampling rate.
    freq : float
        Edge frequency.

    Returns
    -------
    np.ndarray
        Float64 filtered data array, shape (channels, time).

    Warnings
    --------
    Data must be detrended or passed through ``ewma_standardize`` (native
    ``ba_bci_connect_ewma_standartize``) before filtering.
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Private contiguous float64 copy: the C routine filters this buffer in
    # place, so the caller's array stays untouched.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    _dll.ba_bci_connect_filter_lowpass(
        np.ctypeslib.as_ctypes(_x),
        ctypes.c_size_t(chans),
        ctypes.c_size_t(time_points),
        ctypes.c_double(sampling_freq),
        ctypes.c_double(freq),
    )
    # The filtered samples now live in ``_x``; return them directly.
    return _x[: chans * time_points].reshape((chans, time_points))
def fft(x: np.ndarray, sampling_freq: float) -> dict:
    """Compute the discrete Fourier Transform (DFT) with the Fast Fourier
    Transform (FFT) algorithm.

    Parameters
    ----------
    x : np.ndarray
        Data array, shape (channels, time). The caller's array is not modified.
    sampling_freq : float
        Data sampling rate.

    Returns
    -------
    dict
        Dictionary with ``n_bins = time // 2 + 1`` frequency bins:

        - ``"freq"``: frequencies, shape (n_bins,), evenly spaced from 0 to
          ``sampling_freq / 2``
        - ``"mag"``: amplitudes, shape (channels, n_bins); the native
          magnitudes multiplied by 2
        - ``"phase"``: phases, shape (channels, n_bins)
    """
    chans = x.shape[0]
    time_points = x.shape[1]
    # Single contiguous float64 copy of the input for the C routine.
    _x = np.array(x, dtype=np.float64, order="C").ravel()
    n_bins = time_points // 2 + 1
    # Flat output buffers (the argtypes expect plain double pointers); the C
    # routine fills them through the ctypes views.
    mags = np.zeros(chans * n_bins, dtype=np.float64)
    phases = np.zeros(chans * n_bins, dtype=np.float64)
    _dll.ba_bci_connect_fft(
        np.ctypeslib.as_ctypes(_x),
        chans,
        time_points,
        sampling_freq,
        np.ctypeslib.as_ctypes(mags),
        np.ctypeslib.as_ctypes(phases),
    )
    # NOTE(review): this axis always ends exactly at sampling_freq / 2. For an
    # odd number of samples that only matches the bins if the native FFT
    # works on an even-length (truncated) input; otherwise bin k sits at
    # k * sampling_freq / time_points (np.fft.rfftfreq). Confirm against C code.
    freqs = np.linspace(0, sampling_freq / 2, n_bins)
    return {
        "freq": freqs,
        "mag": mags.reshape((chans, n_bins)) * 2,
        "phase": phases.reshape((chans, n_bins)),
    }