import ctypes
from brainaccess.connect import _dll
import numpy as np
# ctypes
# Declare the C signatures of the native P300 entry points so ctypes
# marshals arguments and return values correctly instead of guessing.
for _symbol, _argtypes, _restype in (
    # init(void **, uint8) -> uint8 status
    (
        "ba_bci_connect_p300_init",
        [ctypes.POINTER(ctypes.c_void_p), ctypes.c_uint8],
        ctypes.c_uint8,
    ),
    # predict(void *, double *, float *) -> uint8 status
    (
        "ba_bci_connect_p300_predict",
        [
            ctypes.c_void_p,
            ctypes.POINTER(ctypes.c_double),
            ctypes.POINTER(ctypes.c_float),
        ],
        ctypes.c_uint8,
    ),
    # free(void *) -> void
    ("ba_bci_connect_p300_free", [ctypes.c_void_p], None),
):
    _native = getattr(_dll, _symbol)
    _native.argtypes = _argtypes
    _native.restype = _restype
# Keep the module namespace free of the loop temporaries.
del _symbol, _argtypes, _restype, _native
class P300:
    """P300 BCI classifier backed by the native BrainAccess library.

    Use as a context manager, or call ``destroy()`` when done, so the
    native model handle is released.
    """

    def __init__(self, repetitions: int) -> None:
        """Initialize the P300 model for the selected number of repetitions.

        Parameters
        ------------
        repetitions: int
            number of repetitions to consider for each stimulus,
            can be either 1 or 3.

        Raises
        -------
        Exception
            An error is raised if initializing failed
        """
        # Opaque handle filled in by the native init call.
        self._p300 = ctypes.c_void_p()
        self.reps: int = repetitions
        err = _dll.ba_bci_connect_p300_init(ctypes.pointer(self._p300), repetitions)
        if err != 0:
            raise Exception("P300 model failed to load")

    def __enter__(self) -> "P300":
        """Return the model itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Release the native model when leaving the ``with`` block."""
        self.destroy()

    def destroy(self) -> None:
        """Release the native model.

        Safe to call more than once: after the first call the handle is
        reset to NULL and later calls do nothing.
        """
        if self._p300.value is None:
            # Never loaded or already freed; freeing again would double-free.
            return
        _dll.ba_bci_connect_p300_free(self._p300)
        # Drop the dangling pointer so that __exit__ after an explicit
        # destroy() is a no-op and predict() can detect the freed model.
        self._p300 = ctypes.c_void_p()

    def predict(self, x: np.ndarray) -> float:
        """Predict P300

        Parameters
        ------------
        x: np.ndarray
            data for classifier, shape (8, 176 * repetitions)
            (channels x samples); array-likes are converted with
            ``np.asarray``

        Returns
        ---------
        float:
            probability that data was P300 event

        Raises
        -------
        RuntimeError
            If the model has already been destroyed
        ValueError
            If ``x`` is not 2-D or has the wrong number of channels/samples
        Exception
            An error is raised if prediction failed

        Warnings
        ----------
        Data sampled at 250 Hz must have these properties:
            - standardized with ewma and filtered with 1-40 Hz filter
            - (8, 176 * repetitions) shape (channels x samples),
            - each repetition 200 ms prior to stimulus onset up to 500 ms after stimulus onset
            - Channels must be in exactly this order: F3, F4, C3, C4, P3, P4, O1, O2
        """
        if self._p300.value is None:
            raise RuntimeError("P300 model has been destroyed")
        x = np.asarray(x)
        if x.ndim != 2:
            raise ValueError("x must be a 2-D (channels x samples) array")
        nchans, nsamples = x.shape
        if nchans != 8:
            raise ValueError("8 channels required")
        if nsamples != 176 * self.reps:
            raise ValueError(f"{176 * self.reps} samples required")
        # Single private C-contiguous float64 copy. The native signature takes
        # a non-const double*, so the caller's buffer is never handed over.
        data = np.array(x, dtype=np.double, order="C").ravel()
        c_arr = np.ctypeslib.as_ctypes(data)
        result = ctypes.c_float()
        err = _dll.ba_bci_connect_p300_predict(self._p300, c_arr, ctypes.byref(result))
        if err != 0:
            raise Exception("P300 prediction failed")
        # Read the scalar directly; float() on a 1-element ndarray is deprecated.
        return float(result.value)