Basic single trial fNIRS finger tapping classification
This notebook sketches the analysis of a finger-tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left-hand from right-hand finger tapping in single trials.
[1]:
import cedalion
import cedalion.io
import cedalion.nirs
import cedalion.xrutils as xrutils
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import numpy as np
import xarray as xr
import pint
import matplotlib.pyplot as p
import scipy.signal
import os.path
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.metrics import accuracy_score
xr.set_options(display_max_rows=3, display_values_threshold=50)
np.set_printoptions(precision=4)
Loading raw CW-NIRS data from a SNIRF file
This notebook uses a finger-tapping dataset in BIDS layout provided by Rob Luke. It can be downloaded via cedalion.datasets.
xarray provides another container: xr.Dataset. A Dataset is a collection of xr.DataArray objects that share coordinate axes, so it can be used to group related arrays together.
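As a minimal sketch of how a Dataset groups arrays on shared axes, the following uses made-up values instead of the SNIRF recordings. The channel names and the attribute `note` are placeholders chosen for illustration.

```python
import numpy as np
import xarray as xr

# Toy data: 2 channels x 5 time samples (values are made up for illustration)
time = np.arange(5) * 0.128
channels = ["S1D1", "S1D2"]

amp = xr.DataArray(
    np.random.default_rng(0).uniform(0.5, 1.5, size=(2, 5)),
    dims=["channel", "time"],
    coords={"channel": channels, "time": time},
)
# a derived array that lives on the same axes
od = -np.log(amp / amp.mean("time"))

# group both arrays; they share the channel and time coordinates
ds = xr.Dataset({"amp": amp, "od": od}, attrs={"note": "toy example"})

# selecting along a shared coordinate applies to every variable at once
first_channel = ds.sel(channel="S1D1")
print(first_channel["amp"].shape, first_channel["od"].shape)
```

Because selection works on the Dataset as a whole, derived quantities stay aligned with the raw data they were computed from.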
[2]:
fnames = get_multisubject_fingertapping_snirf_paths()
subjects = [f"sub-{i:02d}" for i in [1,2,3,4,5]]
# store data of different subjects in a dictionary
data = {}
for subject,fname in zip(subjects, fnames):
    elements = cedalion.io.read_snirf(fname)
    amp = elements[0].data[0]
    stim = elements[0].stim # pandas DataFrame
    geo3d = elements[0].geo3d
    # cedalion registers an accessor (attribute .cd ) on pandas DataFrames
    stim.cd.rename_events( {
        "1.0" : "control",
        "2.0" : "Tapping/Left",
        "3.0" : "Tapping/Right"
    })
    dpf = xr.DataArray([6, 6], dims="wavelength", coords={"wavelength" : amp.wavelength})
    data[subject] = xr.Dataset(
        data_vars = {
            "amp" : amp,
            "od" : - np.log( amp / amp.mean("time") ),
            "geo" : geo3d,
            "conc": cedalion.nirs.beer_lambert(amp, geo3d, dpf)
        },
        attrs={"stim" : stim}, # store stimulus data in attrs
        coords={"subject" : subject} # add the subject label as a coordinate
    )
Downloading file 'multisubject-fingertapping.zip' from 'https://doc.ml.tu-berlin.de/cedalion/datasets/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion'.
Unzipping contents of '/home/runner/.cache/cedalion/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion/multisubject-fingertapping.zip.unzip'
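The "od" entry in the cell above is the negative natural logarithm of the amplitude divided by its temporal mean. A small worked example with made-up amplitudes for a single channel and wavelength shows the sign convention: samples below the mean give positive optical density.

```python
import numpy as np

# Made-up raw amplitudes for one channel and one wavelength
amp = np.array([1.00, 1.02, 0.98, 0.95, 1.05])

# optical density relative to the temporal mean (here the mean is 1.0)
od = -np.log(amp / amp.mean())

print(od)
```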
Illustrate the dataset of one subject
[3]:
display(data["sub-01"])
<xarray.Dataset> Size: 32MB
Dimensions:  (time: 23239, channel: 28, wavelength: 2, label: 55, pos: 3, chromo: 2)
Coordinates: (3/10)
  * time     (time) float64 186kB 0.0 0.128 0.256 ... 2.974e+03 2.974e+03
    samples  (time) int64 186kB 0 1 2 3 4 5 ... 23234 23235 23236 23237 23238
    ...       ...
    subject  <U6 24B 'sub-01'
Dimensions without coordinates: pos
Data variables: (3/4)
    amp      (channel, wavelength, time) float64 10MB [] 0.09137 ... 1.276
    od       (channel, wavelength, time) float64 10MB [] 0.04042 ... -0.01317
    ...       ...
    conc     (chromo, channel, time) float64 10MB [µM] 0.1336 ... -1.076
Attributes:
    stim:     onset duration value trial_type\n0 61.824 ...
Frequency filtering and splitting into epochs
[4]:
for subject, ds in data.items():
    # cedalion registers the accessor .cd on DataArrays
    # to provide common functionality like frequency filters...
    ds["conc_freqfilt"] = ds["conc"].cd.freq_filter(fmin=0.02, fmax=0.5, butter_order=4)
    # ... or epoch splitting
    ds["cfepochs"] = ds["conc_freqfilt"].cd.to_epochs(
        ds.attrs["stim"], # stimulus dataframe
        ["Tapping/Left", "Tapping/Right"], # select events
        before=5, # seconds before stimulus
        after=20 # seconds after stimulus
    )
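To illustrate what a 4th-order Butterworth bandpass between 0.02 Hz and 0.5 Hz does, here is a sketch using scipy.signal on a synthetic signal. It is not cedalion's implementation of freq_filter, whose internals may differ. The sampling interval of 0.128 s is taken from the time axis shown earlier, and the test signal is made up.

```python
import numpy as np
import scipy.signal

fs = 1 / 0.128  # sampling rate in Hz, from the 0.128 s sample spacing
t = np.arange(0, 300, 1 / fs)

slow = np.sin(2 * np.pi * 0.1 * t)         # component inside the passband
offset = 5.0                               # constant offset (0 Hz)
fast = 0.5 * np.sin(2 * np.pi * 2.0 * t)   # component above the passband
signal = slow + offset + fast

# 4th-order Butterworth bandpass between 0.02 Hz and 0.5 Hz
sos = scipy.signal.butter(4, [0.02, 0.5], btype="bandpass", fs=fs, output="sos")
# forward-backward filtering avoids a phase shift
filtered = scipy.signal.sosfiltfilt(sos, signal)

# compare against the in-band component away from the signal edges
mid = slice(len(t) // 4, 3 * len(t) // 4)
print(np.corrcoef(filtered[mid], slow[mid])[0, 1])
```

The offset and the 2 Hz oscillation are suppressed, while the 0.1 Hz component, which lies in the range of the hemodynamic response, passes almost unchanged.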
Plot frequency filtered data
Illustrate for a single subject and channel the effect of the bandpass filter.
[5]:
ds = data["sub-01"]
channel = "S5D7"
f,ax= p.subplots(2,1, figsize=(12,4), sharex=True)
ax[0].plot(ds.time, ds.conc.sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[0].plot(ds.time, ds.conc.sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[1].plot(ds.time, ds.conc_freqfilt.sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[1].plot(ds.time, ds.conc_freqfilt.sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[0].set_xlim(1000,1200)
ax[1].set_xlabel("time / s")
ax[0].set_ylabel(r"$\Delta c$ / $\mu M$")  # raw strings avoid invalid escape sequences
ax[1].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[0].legend(loc="upper left"); ax[1].legend(loc="upper left");
[6]:
display(data["sub-01"]["cfepochs"])
<xarray.DataArray 'cfepochs' (epoch: 60, chromo: 2, channel: 28, reltime: 196)> Size: 5MB <Quantity([[[[-1.9479e-02 -1.9950e-02 -2.0945e-02 ... -1.9397e-01 -2.1930e-01 -2.4552e-01] [-1.0442e-02 -1.1375e-02 -1.2845e-02 ... -2.5376e-01 -2.5998e-01 -2.6454e-01] [-1.2500e-02 -8.8897e-03 -5.8428e-03 ... -2.1148e-01 -2.2557e-01 -2.3975e-01] ... [ 9.5644e-02 1.0075e-01 1.0499e-01 ... -3.2907e-01 -3.4702e-01 -3.6405e-01] [ 2.0733e-02 2.2404e-02 2.3932e-02 ... -2.9462e-01 -2.9954e-01 -3.0371e-01] [ 1.4324e-03 2.0414e-03 4.1026e-03 ... -3.0974e-01 -3.2710e-01 -3.4287e-01]] [[ 1.4641e-02 6.4301e-03 -1.6854e-03 ... -6.5363e-02 -6.2839e-02 -5.7220e-02] [ 2.0950e-03 5.5313e-03 8.9982e-03 ... -4.8858e-02 -5.4990e-02 -6.0070e-02] [ 4.5053e-02 4.2005e-02 3.9300e-02 ... -4.3936e-02 -4.3039e-02 -4.0861e-02] ... [ 2.8454e-01 2.9324e-01 3.0016e-01 ... 1.0617e-01 1.1339e-01 1.2119e-01] [ 3.4306e-01 3.9592e-01 4.4809e-01 ... -1.8578e-01 -1.7872e-01 -1.7142e-01] [ 6.1801e-01 6.1059e-01 6.0291e-01 ... 1.8868e-01 1.9003e-01 1.9076e-01]] [[ 6.5167e-02 5.7419e-02 4.9134e-02 ... 3.1521e-02 3.2053e-02 3.1760e-02] [ 3.6845e-02 3.4700e-02 3.1450e-02 ... -3.0447e-02 -2.7677e-02 -2.3868e-02] [ 4.8888e-02 3.9980e-02 3.2743e-02 ... 3.0387e-02 3.4824e-02 3.9630e-02] ... [ 5.8174e-02 5.3107e-02 4.7739e-02 ... -8.1363e-03 -6.2960e-03 -5.3040e-03] [ 1.5154e-01 1.6202e-01 1.7299e-01 ... 2.6186e-02 3.2992e-02 3.8859e-02] [ 1.5938e-01 1.5348e-01 1.4687e-01 ... 3.8222e-02 4.2833e-02 4.6774e-02]]]], 'micromolar')> Coordinates: (3/7) * channel (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16' source (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8' ... ... trial_type (epoch) object 480B 'Tapping/Left' ... 'Tapping/Right' Dimensions without coordinates: epoch
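The reltime length of 196 follows from the window of 5 s before and 20 s after each stimulus at a sample spacing of 0.128 s. The sketch below reproduces the idea of epoch splitting with plain NumPy. The onsets and the signal are hypothetical, and the exact windowing rule used by to_epochs may differ from this one.

```python
import numpy as np

dt = 0.128                                  # sample spacing in seconds
time = np.arange(0, 200, dt)                # time axis in seconds
signal = np.sin(2 * np.pi * 0.05 * time)    # made-up single-channel signal

onsets = np.array([30.0, 80.0, 130.0])      # hypothetical stimulus onsets in seconds
before, after = 5.0, 20.0

# samples per epoch: whole sample steps that fit in the window, plus the first sample
n_samples = int(np.round((before + after) / dt)) + 1

# index of the first sample at or after (onset - before) for each trial
starts = np.searchsorted(time, onsets - before)

# stack equally long windows into an (epoch, reltime) array
epochs = np.stack([signal[s : s + n_samples] for s in starts])
reltime = np.arange(n_samples) * dt - before

print(epochs.shape)
```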
[7]:
all_epochs = xr.concat([ds["cfepochs"] for ds in data.values()], dim="epoch")
all_epochs
[7]:
<xarray.DataArray 'cfepochs' (epoch: 300, chromo: 2, channel: 28, reltime: 196)> Size: 26MB <Quantity([[[[-1.9479e-02 -1.9950e-02 -2.0945e-02 ... -1.9397e-01 -2.1930e-01 -2.4552e-01] [-1.0442e-02 -1.1375e-02 -1.2845e-02 ... -2.5376e-01 -2.5998e-01 -2.6454e-01] [-1.2500e-02 -8.8897e-03 -5.8428e-03 ... -2.1148e-01 -2.2557e-01 -2.3975e-01] ... [ 9.5644e-02 1.0075e-01 1.0499e-01 ... -3.2907e-01 -3.4702e-01 -3.6405e-01] [ 2.0733e-02 2.2404e-02 2.3932e-02 ... -2.9462e-01 -2.9954e-01 -3.0371e-01] [ 1.4324e-03 2.0414e-03 4.1026e-03 ... -3.0974e-01 -3.2710e-01 -3.4287e-01]] [[ 1.4641e-02 6.4301e-03 -1.6854e-03 ... -6.5363e-02 -6.2839e-02 -5.7220e-02] [ 2.0950e-03 5.5313e-03 8.9982e-03 ... -4.8858e-02 -5.4990e-02 -6.0070e-02] [ 4.5053e-02 4.2005e-02 3.9300e-02 ... -4.3936e-02 -4.3039e-02 -4.0861e-02] ... [-1.8642e-01 -1.8383e-01 -1.8031e-01 ... 9.3582e-03 1.1076e-02 1.3423e-02] [-2.8335e-01 -2.7513e-01 -2.6600e-01 ... -1.3275e-02 -4.6116e-03 2.3698e-03] [-3.7102e-01 -3.6417e-01 -3.5820e-01 ... 1.8997e-01 1.9610e-01 2.0283e-01]] [[ 6.9413e-02 6.7628e-02 6.5086e-02 ... 2.3354e-02 8.2494e-03 -6.5091e-03] [ 9.3748e-02 7.1599e-02 4.8922e-02 ... 6.0442e-03 3.1716e-02 4.8579e-02] [ 1.3260e-01 1.2931e-01 1.2489e-01 ... 1.8629e-01 1.6657e-01 1.4594e-01] ... [ 3.5046e-02 3.1738e-02 2.8594e-02 ... 9.8597e-02 1.0093e-01 1.0222e-01] [ 8.7634e-02 8.3598e-02 8.0126e-02 ... 1.1028e-01 1.0263e-01 9.5370e-02] [-3.3306e-02 -3.6382e-02 -3.7580e-02 ... 4.1337e-03 3.1121e-03 5.3935e-04]]]], 'micromolar')> Coordinates: (3/7) * channel (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16' source (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8' ... ... trial_type (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right' Dimensions without coordinates: epoch
Block Averages
[8]:
# calculate baseline
baseline = all_epochs.sel(reltime=(all_epochs.reltime < 0)).mean("reltime")
# subtract baseline
all_epochs_blcorrected = all_epochs - baseline
# group trials by trial_type. For each group individually average the epoch dimension
blockaverage = all_epochs_blcorrected.groupby("trial_type").mean("epoch")
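The same two steps, baseline subtraction and per-condition averaging, can be sketched with plain NumPy on made-up data. The labels "Left" and "Right" and the array sizes are placeholders for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up epochs: 6 trials x 10 samples, reltime from -0.4 s to 0.5 s
reltime = np.arange(-4, 6) * 0.1
epochs = rng.normal(size=(6, 10)) + 3.0   # constant offset of 3 on every trial
trial_type = np.array(["Left", "Right", "Left", "Right", "Left", "Right"])

# mean over pre-stimulus samples, kept as a column for broadcasting
baseline = epochs[:, reltime < 0].mean(axis=1, keepdims=True)
corrected = epochs - baseline

# average the trials of each condition separately
blockaverage = {
    str(label): corrected[trial_type == label].mean(axis=0)
    for label in np.unique(trial_type)
}

print({label: avg.shape for label, avg in blockaverage.items()})
```

After the correction, the pre-stimulus mean of every trial is zero, so the block averages show changes relative to the interval just before stimulus onset.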
Plotting averaged epochs
[9]:
f,ax = p.subplots(5,6, figsize=(12,10))
ax = ax.flatten()
for i_ch, ch in enumerate(blockaverage.channel):
    for ls, trial_type in zip(["-", "--"], blockaverage.trial_type):
        ax[i_ch].plot(blockaverage.reltime, blockaverage.sel(chromo="HbO", trial_type=trial_type, channel=ch), "r", lw=2, ls=ls)
        ax[i_ch].plot(blockaverage.reltime, blockaverage.sel(chromo="HbR", trial_type=trial_type, channel=ch), "b", lw=2, ls=ls)
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-.3, .6)
p.tight_layout()
Training an LDA classifier with scikit-learn
[10]:
# start with the frequency-filtered, epoched and baseline-corrected concentration data
# discard the samples before the stimulus onset
epochs = all_epochs_blcorrected.sel(reltime=all_epochs_blcorrected.reltime >=0)
# strip units. sklearn would strip them anyway and issue a warning about it.
epochs = epochs.pint.dequantify()
# need to manually tell xarray to create an index for trial_type
epochs = epochs.set_xindex("trial_type")
[11]:
display(epochs)
<xarray.DataArray 'cfepochs' (epoch: 300, chromo: 2, channel: 28, reltime: 157)> Size: 21MB array([[[[-3.6270e-02, -1.8065e-02, 2.0471e-03, ..., -6.8976e-02, -9.4304e-02, -1.2053e-01], [ 4.6475e-03, 1.1837e-02, 1.9097e-02, ..., -1.5186e-01, -1.5808e-01, -1.6264e-01], [ 2.4496e-03, 1.2294e-02, 2.3849e-02, ..., -1.3250e-01, -1.4658e-01, -1.6077e-01], ..., [-4.3450e-02, -3.0909e-02, -1.6201e-02, ..., -3.4983e-01, -3.6779e-01, -3.8482e-01], [ 2.2952e-02, 3.4524e-02, 4.6622e-02, ..., -2.6465e-01, -2.6958e-01, -2.7375e-01], [-4.2557e-03, 1.3953e-02, 3.3226e-02, ..., -2.5881e-01, -2.7617e-01, -2.9193e-01]], [[ 1.5914e-02, 1.3944e-02, 1.0925e-02, ..., -7.7108e-02, -7.4584e-02, -6.8965e-02], [-1.2916e-02, -9.4326e-03, -4.7285e-03, ..., -5.5005e-02, -6.1136e-02, -6.6217e-02], [-6.1391e-03, -2.1726e-03, -2.9567e-05, ..., -8.1372e-02, -8.0476e-02, -7.8297e-02], ... [ 1.5353e-01, 1.5630e-01, 1.5688e-01, ..., 6.7293e-02, 6.9011e-02, 7.1357e-02], [ 2.4956e-01, 2.4903e-01, 2.4447e-01, ..., 7.8465e-02, 8.7128e-02, 9.4109e-02], [ 4.8391e-01, 4.8248e-01, 4.7416e-01, ..., 3.1090e-01, 3.1703e-01, 3.2376e-01]], [[-1.5826e-02, -1.2935e-02, -5.2803e-03, ..., 2.3303e-03, -1.2774e-02, -2.7533e-02], [-1.3622e-02, -1.8221e-02, -2.3395e-02, ..., 1.1320e-02, 3.6992e-02, 5.3854e-02], [-3.5104e-02, -3.3747e-02, -3.9994e-02, ..., 7.3406e-02, 5.3683e-02, 3.3055e-02], ..., [ 1.1524e-02, 1.4339e-02, 1.6261e-02, ..., 9.8163e-02, 1.0049e-01, 1.0178e-01], [-5.6057e-02, -5.8393e-02, -5.9799e-02, ..., 6.0866e-02, 5.3214e-02, 4.5954e-02], [ 3.0087e-02, 3.2592e-02, 3.6472e-02, ..., -5.2315e-03, -6.2531e-03, -8.8259e-03]]]]) Coordinates: (3/7) * channel (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16' source (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8' ... ... * trial_type (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right' Dimensions without coordinates: epoch Attributes: units: micromolar
[12]:
X = epochs.stack(features=["chromo", "channel", "reltime"])
display(X)
<xarray.DataArray 'cfepochs' (epoch: 300, features: 8792)> Size: 21MB
array([[-0.0363, -0.0181,  0.002 , ..., -0.0126, -0.0093, -0.0072],
       [ 0.109 ,  0.0983,  0.0847, ..., -0.0343, -0.036 , -0.0378],
       [-0.0203, -0.0267, -0.0325, ..., -0.0187, -0.0155, -0.0115],
       ...,
       [ 0.029 ,  0.0078, -0.0164, ..., -0.0117, -0.0013,  0.0085],
       [-0.2281, -0.2284, -0.2273, ...,  0.0104,  0.0108,  0.0135],
       [ 0.2505,  0.2608,  0.2643, ..., -0.0052, -0.0063, -0.0088]])
Coordinates: (3/8)
    source    (features) object 70kB 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    detector  (features) object 70kB 'D1' 'D1' 'D1' 'D1' ... 'D16' 'D16' 'D16'
    ...        ...
  * reltime   (features) float64 70kB 0.0 0.128 0.256 ... 19.71 19.84 19.97
Dimensions without coordinates: epoch
Attributes:
    units:    micromolar
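Stacking the chromo, channel and reltime dimensions turns each epoch into one flat feature vector of length 2 x 28 x 157 = 8792. A NumPy sketch of the equivalent reshape, using an array of zeros with the shapes shown above, is given below. It assumes the stacked axis is ordered with the last dimension varying fastest, which is consistent with the reltime values in the output.

```python
import numpy as np

n_epoch, n_chromo, n_channel, n_reltime = 300, 2, 28, 157
epochs = np.zeros((n_epoch, n_chromo, n_channel, n_reltime))

# keep the epoch axis, flatten everything else into one feature axis
X = epochs.reshape(n_epoch, -1)
print(X.shape)  # (300, 8792)

# position of (chromo=1, channel=3, reltime=10) on the flattened feature axis
flat_index = np.ravel_multi_index((1, 3, 10), (n_chromo, n_channel, n_reltime))
print(flat_index)  # 4877
```

With 8792 features and only 300 epochs, the classifier sees far more dimensions than samples, which is worth keeping in mind when interpreting the accuracy.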
[13]:
y = xr.apply_ufunc(LabelEncoder().fit_transform, X.trial_type)
display(y)
<xarray.DataArray 'trial_type' (epoch: 300)> Size: 2kB array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) Coordinates: subject (epoch) <U6 7kB 'sub-01' 'sub-01' 'sub-01' ... 'sub-05' 'sub-05' * trial_type (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right' Dimensions without coordinates: epoch
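LabelEncoder assigns integer codes to the sorted unique labels, which is why "Tapping/Left" becomes 0 and "Tapping/Right" becomes 1. A standalone sketch on a short, made-up list of labels:

```python
from sklearn.preprocessing import LabelEncoder

labels = ["Tapping/Left", "Tapping/Right", "Tapping/Right", "Tapping/Left"]

encoder = LabelEncoder()
y = encoder.fit_transform(labels)

print(encoder.classes_.tolist())   # ['Tapping/Left', 'Tapping/Right']
print(y.tolist())                  # [0, 1, 1, 0]

# the fitted encoder also maps integer codes back to the original labels
print(encoder.inverse_transform([1, 0]).tolist())
```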
[14]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3, stratify=y)
classifier = LinearDiscriminantAnalysis(n_components=1).fit(X_train, y_train)
y_pred = classifier.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
Accuracy: 0.8111111111111111
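The accuracy above comes from a single random split without a fixed random_state, so it will vary between runs. One way to obtain a more stable estimate is k-fold cross-validation with cross_validate, which is imported at the top of the notebook but not used. The sketch below runs on synthetic data as a stand-in for X and y. The sample count, feature count and class shift are assumptions made for illustration.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_validate

rng = np.random.default_rng(0)

# Synthetic stand-in for (X, y): 100 trials, 20 features, two classes
y = np.repeat([0, 1], 50)
X = rng.normal(size=(100, 20)) + y[:, None] * 1.0   # class 1 is shifted by 1

# stratified 5-fold split with a fixed seed for reproducibility
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
result = cross_validate(LinearDiscriminantAnalysis(), X, y, cv=cv, scoring="accuracy")

scores = result["test_score"]
print(f"Accuracy: {scores.mean():.2f} +/- {scores.std():.2f}")
```

Since the epochs of all subjects are pooled here, a group-aware splitter such as scikit-learn's GroupKFold with the subject label as group could be used instead to test generalization to unseen subjects.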
[15]:
f,ax = p.subplots(1,2, figsize=(12,3))
for trial_type, c in zip(["Tapping/Left", "Tapping/Right"], ["r", "g"]):
    kw = dict(alpha=.5, fc=c, label=trial_type)
    ax[0].hist(classifier.decision_function(X_train.sel(trial_type=trial_type)), **kw)
    ax[1].hist(classifier.decision_function(X_test.sel(trial_type=trial_type)), **kw)
ax[0].set_xlabel("LDA score")
ax[1].set_xlabel("LDA score")
ax[0].set_title("train")
ax[1].set_title("test")
ax[0].legend(ncol=1, loc="upper left")
ax[1].legend(ncol=1, loc="upper left");
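The histograms show the LDA decision scores. For a two-class problem, scikit-learn returns one score per sample and predicts the second class whenever the score is positive. The sketch below checks this relation on synthetic data. The data and the variable name pred_from_scores are made up for illustration.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(1)

# Synthetic two-class data: 80 samples, 5 features, class 1 shifted by 2
y = np.repeat([0, 1], 40)
X = rng.normal(size=(80, 5)) + y[:, None] * 2.0

clf = LinearDiscriminantAnalysis().fit(X, y)
scores = clf.decision_function(X)   # shape (n_samples,) for two classes

# the predicted class follows from the sign of the score
pred_from_scores = clf.classes_[(scores > 0).astype(int)]
print(np.array_equal(pred_from_scores, clf.predict(X)))
```

A clear separation of the two histograms around zero on the test set therefore corresponds directly to a high test accuracy.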