Source code for brainaccess.connect.P300

import ctypes
from brainaccess.connect import _dll
import numpy as np

# ctypes

_dll.ba_bci_connect_p300_init.argtypes = [
    ctypes.POINTER(ctypes.c_void_p),
    ctypes.c_uint8
]
_dll.ba_bci_connect_p300_init.restype = ctypes.c_uint8

_dll.ba_bci_connect_p300_predict.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_float)
]
_dll.ba_bci_connect_p300_predict.restype = ctypes.c_uint8

_dll.ba_bci_connect_p300_free.argtypes = [
    ctypes.c_void_p
]
_dll.ba_bci_connect_p300_free.restype = None


[docs]class P300: """P300 BCI library""" def __init__(self, repetitions: int) -> None: """Initialize P300 model as per selected number of repetitions. Parameters ------------ repetitions: int number of repetitions to consider on each stimuli, can be either 1 or 3. Raises ------- Exception An error is raised if initializing failed """ self._p300 = ctypes.c_void_p() self.reps: int = repetitions err = _dll.ba_bci_connect_p300_init(ctypes.pointer(self._p300), repetitions) if err != 0: raise Exception("P300 model failed to load") def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback) -> None: self.destroy()
[docs] def destroy(self) -> None: _dll.ba_bci_connect_p300_free(self._p300)
[docs] def predict(self, x: np.ndarray) -> float: """Predict P300 Parameters ------------ x: np.ndarray data for classifier Returns --------- float: probability that data was P300 event Raises ------- Exception An error is raised if prediction failed Warnings ---------- Data sampled at 250 Hz must have these properties: - standardized with ewma and filtered with 1-40 Hz filter - (8, 176 * repetitions) shape (channels x samples), - each repetition 200 ms prior to stimulus onset up to 500 ms after stimulus onset - Channels must be in exactly this order: F3, F4, C3, C4, P3, P4, O1, O2 """ nchans = x.shape[0] nsamples = x.shape[1] if nchans != 8: raise Exception("8 channels required") if nsamples != 176 * self.reps: raise Exception(f"{176 * self.reps} samples required") _x = x.copy().ravel(order="C").astype(np.double) c_arr = np.ctypeslib.as_ctypes(_x) _res = np.zeros(1).astype(np.float32) c_res = np.ctypeslib.as_ctypes(_res) _error = _dll.ba_bci_connect_p300_predict(self._p300, c_arr, c_res) if _error != 0: raise Exception("P300 prediction failed") return float(np.ctypeslib.as_array(c_res, shape=(1,)))