Basic single trial fNIRS finger tapping classification
This notebook sketches the analysis of a finger-tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left from right finger tapping.
[1]:
import cedalion
import cedalion.nirs
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import numpy as np
import xarray as xr
import matplotlib.pyplot as p
from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
xr.set_options(display_max_rows=3, display_values_threshold=50)
np.set_printoptions(precision=4)
Loading raw CW-NIRS data from a SNIRF file
This notebook uses a finger-tapping dataset in BIDS layout provided by Rob Luke. It can be downloaded via cedalion.datasets.
Cedalion’s read_snirf method returns a list of Recording objects. These are containers for timeseries and adjunct data objects.
[2]:
fnames = get_multisubject_fingertapping_snirf_paths()
subjects = [f"sub-{i:02d}" for i in [1, 2, 3, 4, 5]]
# store data of different subjects in a dictionary
data = {}
for subject, fname in zip(subjects, fnames):
    records = cedalion.io.read_snirf(fname)
    rec = records[0]
    display(rec)

    # Cedalion registers an accessor (attribute .cd) on pandas DataFrames.
    # Use this to rename trial_types in place.
    rec.stim.cd.rename_events(
        {"1.0": "control", "2.0": "Tapping/Left", "3.0": "Tapping/Right"}
    )

    dpf = xr.DataArray(
        [6, 6],
        dims="wavelength",
        coords={"wavelength": rec["amp"].wavelength},
    )

    rec["od"] = -np.log(rec["amp"] / rec["amp"].mean("time"))
    rec["conc"] = cedalion.nirs.beer_lambert(rec["amp"], rec.geo3d, dpf)

    data[subject] = rec
Downloading file 'multisubject-fingertapping.zip' from 'https://doc.ml.tu-berlin.de/cedalion/datasets/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion'.
Unzipping contents of '/home/runner/.cache/cedalion/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion/multisubject-fingertapping.zip.unzip'
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
Illustrate the dataset of one subject
[3]:
display(data["sub-01"])
<Recording | timeseries: ['amp', 'od', 'conc'], masks: [], stim: ['control', '15.0', 'Tapping/Left', 'Tapping/Right'], aux_ts: [], aux_obj: []>
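The Recording container bundles the raw amplitudes, the derived time series, the stimulus table and the probe geometry. As a quick orientation (a small sketch, not an original cell of this notebook), individual pieces can be pulled out like this:

rec = data["sub-01"]
display(rec["amp"])   # raw CW amplitude time series
display(rec["conc"])  # HbO/HbR concentrations computed above
display(rec.stim)     # stimulus DataFrame with the renamed trial types
display(rec.geo3d)    # 3D probe geometry passed to beer_lambert above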
Frequency filtering and splitting into epochs
[4]:
for subject, rec in data.items():
    # cedalion registers the accessor .cd on DataArrays
    # to provide common functionality like frequency filters ...
    rec["conc_freqfilt"] = rec["conc"].cd.freq_filter(
        fmin=0.02, fmax=0.5, butter_order=4
    )

    # ... or epoch splitting
    rec["cfepochs"] = rec["conc_freqfilt"].cd.to_epochs(
        rec.stim,  # stimulus dataframe
        ["Tapping/Left", "Tapping/Right"],  # select events
        before=5,  # seconds before stimulus
        after=20,  # seconds after stimulus
    )
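As a quick sanity check (a sketch, not an original cell), the dimensions of the epoched array and the number of trials per condition can be inspected for one subject:

cfepochs = data["sub-01"]["cfepochs"]
print(cfepochs.sizes)  # expect the dims epoch, chromo, channel and reltime
print(np.unique(cfepochs.trial_type.values, return_counts=True))  # trials per condition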
Plot frequency filtered data
Illustrate the effect of the bandpass filter for a single subject and channel.
[5]:
rec = data["sub-01"]
channel = "S5D7"
f, ax = p.subplots(2, 1, figsize=(12, 4), sharex=True)
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[1].plot(
    rec["conc_freqfilt"].time,
    rec["conc_freqfilt"].sel(channel=channel, chromo="HbO"),
    "r-",
    label="HbO",
)
ax[1].plot(
    rec["conc_freqfilt"].time,
    rec["conc_freqfilt"].sel(channel=channel, chromo="HbR"),
    "b-",
    label="HbR",
)
ax[0].set_xlim(1000, 1200)
ax[1].set_xlabel("time / s")
ax[0].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[1].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[0].legend(loc="upper left")
ax[1].legend(loc="upper left");
[6]:
display(data["sub-01"]["cfepochs"])
<xarray.DataArray 'concentration' (epoch: 60, chromo: 2, channel: 28,
reltime: 196)> Size: 5MB
<Quantity([[[[-1.9479e-02 -1.9950e-02 -2.0945e-02 ... -1.9397e-01 -2.1930e-01
-2.4552e-01]
[-1.0442e-02 -1.1375e-02 -1.2845e-02 ... -2.5376e-01 -2.5998e-01
-2.6454e-01]
[-1.2500e-02 -8.8897e-03 -5.8428e-03 ... -2.1148e-01 -2.2557e-01
-2.3975e-01]
...
[ 9.5644e-02 1.0075e-01 1.0499e-01 ... -3.2907e-01 -3.4702e-01
-3.6405e-01]
[ 2.0733e-02 2.2404e-02 2.3932e-02 ... -2.9462e-01 -2.9954e-01
-3.0371e-01]
[ 1.4324e-03 2.0415e-03 4.1027e-03 ... -3.0974e-01 -3.2710e-01
-3.4287e-01]]
[[ 1.4641e-02 6.4301e-03 -1.6853e-03 ... -6.5363e-02 -6.2839e-02
-5.7220e-02]
[ 2.0950e-03 5.5313e-03 8.9983e-03 ... -4.8858e-02 -5.4990e-02
-6.0070e-02]
[ 4.5053e-02 4.2005e-02 3.9300e-02 ... -4.3936e-02 -4.3039e-02
-4.0861e-02]
...
[ 2.8454e-01 2.9324e-01 3.0016e-01 ... 1.0617e-01 1.1339e-01
1.2119e-01]
[ 3.4306e-01 3.9592e-01 4.4809e-01 ... -1.8578e-01 -1.7872e-01
-1.7142e-01]
[ 6.1801e-01 6.1059e-01 6.0291e-01 ... 1.8868e-01 1.9003e-01
1.9076e-01]]
[[ 6.5167e-02 5.7419e-02 4.9134e-02 ... 3.1521e-02 3.2053e-02
3.1760e-02]
[ 3.6845e-02 3.4700e-02 3.1450e-02 ... -3.0447e-02 -2.7677e-02
-2.3868e-02]
[ 4.8888e-02 3.9980e-02 3.2743e-02 ... 3.0387e-02 3.4824e-02
3.9630e-02]
...
[ 5.8174e-02 5.3107e-02 4.7739e-02 ... -8.1364e-03 -6.2961e-03
-5.3040e-03]
[ 1.5154e-01 1.6202e-01 1.7299e-01 ... 2.6186e-02 3.2992e-02
3.8859e-02]
[ 1.5938e-01 1.5348e-01 1.4687e-01 ... 3.8222e-02 4.2833e-02
4.6774e-02]]]], 'micromolar')>
Coordinates: (3/6)
* chromo (chromo) <U3 24B 'HbO' 'HbR'
* channel (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
... ...
trial_type (epoch) object 480B 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
[7]:
all_epochs = xr.concat([rec["cfepochs"] for rec in data.values()], dim="epoch")
all_epochs
[7]:
<xarray.DataArray 'concentration' (epoch: 300, chromo: 2, channel: 28,
reltime: 196)> Size: 26MB
<Quantity([[[[-1.9479e-02 -1.9950e-02 -2.0945e-02 ... -1.9397e-01 -2.1930e-01
-2.4552e-01]
[-1.0442e-02 -1.1375e-02 -1.2845e-02 ... -2.5376e-01 -2.5998e-01
-2.6454e-01]
[-1.2500e-02 -8.8897e-03 -5.8428e-03 ... -2.1148e-01 -2.2557e-01
-2.3975e-01]
...
[ 9.5644e-02 1.0075e-01 1.0499e-01 ... -3.2907e-01 -3.4702e-01
-3.6405e-01]
[ 2.0733e-02 2.2404e-02 2.3932e-02 ... -2.9462e-01 -2.9954e-01
-3.0371e-01]
[ 1.4324e-03 2.0415e-03 4.1027e-03 ... -3.0974e-01 -3.2710e-01
-3.4287e-01]]
[[ 1.4641e-02 6.4301e-03 -1.6853e-03 ... -6.5363e-02 -6.2839e-02
-5.7220e-02]
[ 2.0950e-03 5.5313e-03 8.9983e-03 ... -4.8858e-02 -5.4990e-02
-6.0070e-02]
[ 4.5053e-02 4.2005e-02 3.9300e-02 ... -4.3936e-02 -4.3039e-02
-4.0861e-02]
...
[-1.8642e-01 -1.8383e-01 -1.8031e-01 ... 9.3582e-03 1.1076e-02
1.3423e-02]
[-2.8335e-01 -2.7513e-01 -2.6600e-01 ... -1.3274e-02 -4.6116e-03
2.3699e-03]
[-3.7102e-01 -3.6417e-01 -3.5820e-01 ... 1.8997e-01 1.9610e-01
2.0283e-01]]
[[ 6.9413e-02 6.7628e-02 6.5086e-02 ... 2.3354e-02 8.2493e-03
-6.5091e-03]
[ 9.3748e-02 7.1599e-02 4.8922e-02 ... 6.0441e-03 3.1716e-02
4.8579e-02]
[ 1.3260e-01 1.2931e-01 1.2489e-01 ... 1.8629e-01 1.6657e-01
1.4594e-01]
...
[ 3.5046e-02 3.1738e-02 2.8594e-02 ... 9.8597e-02 1.0093e-01
1.0222e-01]
[ 8.7634e-02 8.3598e-02 8.0126e-02 ... 1.1028e-01 1.0263e-01
9.5370e-02]
[-3.3306e-02 -3.6382e-02 -3.7580e-02 ... 4.1337e-03 3.1121e-03
5.3934e-04]]]], 'micromolar')>
Coordinates: (3/6)
* chromo (chromo) <U3 24B 'HbO' 'HbR'
* channel (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
... ...
trial_type (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
Block Averages
[8]:
# calculate baseline
baseline = all_epochs.sel(reltime=(all_epochs.reltime < 0)).mean("reltime")
# subtract baseline
all_epochs_blcorrected = all_epochs - baseline
# group trials by trial_type. For each group individually average the epoch dimension
blockaverage = all_epochs_blcorrected.groupby("trial_type").mean("epoch")
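The same groupby pattern also gives the trial-to-trial variability, which can be useful alongside the mean response. A hedged sketch (not part of the original notebook):

# standard deviation across epochs, per trial_type
blockstd = all_epochs_blcorrected.groupby("trial_type").std("epoch")
# number of trials contributing to each condition
print(all_epochs_blcorrected.trial_type.to_series().value_counts())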
Plotting averaged epochs
[9]:
f, ax = p.subplots(5, 6, figsize=(12, 10))
ax = ax.flatten()
for i_ch, ch in enumerate(blockaverage.channel):
    for ls, trial_type in zip(["-", "--"], blockaverage.trial_type):
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbO", trial_type=trial_type, channel=ch),
            "r",
            lw=2,
            ls=ls,
        )
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbR", trial_type=trial_type, channel=ch),
            "b",
            lw=2,
            ls=ls,
        )
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-0.3, 0.6)
p.tight_layout()
Training an LDA classifier with scikit-learn
[10]:
# start with the frequency-filtered, epoched and baseline-corrected concentration data
# discard the samples before the stimulus onset
epochs = all_epochs_blcorrected.sel(reltime=all_epochs_blcorrected.reltime >= 0)
# strip units. sklearn would strip them anyway and issue a warning about it.
epochs = epochs.pint.dequantify()
# need to manually tell xarray to create an index for trial_type
epochs = epochs.set_xindex("trial_type")
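With trial_type registered as an index, epochs can be selected by condition via .sel, which the classifier evaluation below relies on. A brief illustration (assuming the cell above has just been run):

# select all left-tapping epochs by their trial_type label
display(epochs.sel(trial_type="Tapping/Left").sizes)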
[11]:
display(epochs)
<xarray.DataArray 'concentration' (epoch: 300, chromo: 2, channel: 28,
reltime: 157)> Size: 21MB
array([[[[-3.6270e-02, -1.8064e-02, 2.0471e-03, ..., -6.8976e-02,
-9.4304e-02, -1.2053e-01],
[ 4.6475e-03, 1.1837e-02, 1.9097e-02, ..., -1.5186e-01,
-1.5808e-01, -1.6264e-01],
[ 2.4496e-03, 1.2294e-02, 2.3849e-02, ..., -1.3250e-01,
-1.4658e-01, -1.6077e-01],
...,
[-4.3450e-02, -3.0909e-02, -1.6201e-02, ..., -3.4983e-01,
-3.6779e-01, -3.8482e-01],
[ 2.2952e-02, 3.4524e-02, 4.6622e-02, ..., -2.6465e-01,
-2.6958e-01, -2.7375e-01],
[-4.2557e-03, 1.3953e-02, 3.3226e-02, ..., -2.5881e-01,
-2.7617e-01, -2.9194e-01]],
[[ 1.5914e-02, 1.3944e-02, 1.0925e-02, ..., -7.7108e-02,
-7.4584e-02, -6.8965e-02],
[-1.2916e-02, -9.4326e-03, -4.7285e-03, ..., -5.5005e-02,
-6.1136e-02, -6.6217e-02],
[-6.1391e-03, -2.1726e-03, -2.9566e-05, ..., -8.1372e-02,
-8.0476e-02, -7.8297e-02],
...
[ 1.5353e-01, 1.5630e-01, 1.5688e-01, ..., 6.7293e-02,
6.9011e-02, 7.1357e-02],
[ 2.4956e-01, 2.4903e-01, 2.4447e-01, ..., 7.8465e-02,
8.7128e-02, 9.4109e-02],
[ 4.8391e-01, 4.8248e-01, 4.7416e-01, ..., 3.1090e-01,
3.1703e-01, 3.2376e-01]],
[[-1.5826e-02, -1.2935e-02, -5.2803e-03, ..., 2.3303e-03,
-1.2774e-02, -2.7533e-02],
[-1.3622e-02, -1.8221e-02, -2.3395e-02, ..., 1.1320e-02,
3.6992e-02, 5.3854e-02],
[-3.5104e-02, -3.3747e-02, -3.9994e-02, ..., 7.3406e-02,
5.3683e-02, 3.3055e-02],
...,
[ 1.1524e-02, 1.4339e-02, 1.6261e-02, ..., 9.8163e-02,
1.0049e-01, 1.0178e-01],
[-5.6057e-02, -5.8393e-02, -5.9799e-02, ..., 6.0866e-02,
5.3214e-02, 4.5954e-02],
[ 3.0087e-02, 3.2592e-02, 3.6472e-02, ..., -5.2316e-03,
-6.2531e-03, -8.8259e-03]]]])
Coordinates: (3/6)
* chromo (chromo) <U3 24B 'HbO' 'HbR'
* channel (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
... ...
* trial_type (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
Attributes:
units: micromolar
[12]:
X = epochs.stack(features=["chromo", "channel", "reltime"])
display(X)
<xarray.DataArray 'concentration' (epoch: 300, features: 8792)> Size: 21MB
array([[-0.0363, -0.0181, 0.002 , ..., -0.0126, -0.0093, -0.0072],
[ 0.109 , 0.0983, 0.0847, ..., -0.0343, -0.036 , -0.0378],
[-0.0203, -0.0267, -0.0325, ..., -0.0187, -0.0155, -0.0115],
...,
[ 0.029 , 0.0078, -0.0164, ..., -0.0117, -0.0013, 0.0085],
[-0.2281, -0.2284, -0.2273, ..., 0.0104, 0.0108, 0.0135],
[ 0.2505, 0.2608, 0.2643, ..., -0.0052, -0.0063, -0.0088]])
Coordinates: (3/7)
source (features) object 70kB 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
detector (features) object 70kB 'D1' 'D1' 'D1' 'D1' ... 'D16' 'D16' 'D16'
... ...
* reltime (features) float64 70kB 0.0 0.128 0.256 ... 19.71 19.84 19.97
Dimensions without coordinates: epoch
Attributes:
units: micromolar
[13]:
y = xr.apply_ufunc(LabelEncoder().fit_transform, X.trial_type)
display(y)
<xarray.DataArray 'trial_type' (epoch: 300)> Size: 2kB
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Coordinates:
* trial_type (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
[14]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
classifier = LinearDiscriminantAnalysis(n_components=1).fit(X_train, y_train)
y_pred = classifier.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
Accuracy: 0.8888888888888888
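The single train/test split above yields one point estimate. A cross-validated estimate depends less on the particular split; below is a minimal sketch (not an original cell, and note that it pools epochs from all subjects, so it does not measure subject-independent performance):

from sklearn.model_selection import StratifiedKFold, cross_val_score

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LinearDiscriminantAnalysis(), X.values, y.values, cv=cv)
print(f"CV accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")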
[15]:
f, ax = p.subplots(1, 2, figsize=(12, 3))
for trial_type, c in zip(["Tapping/Left", "Tapping/Right"], ["r", "g"]):
    kw = dict(alpha=0.5, fc=c, label=trial_type)
    ax[0].hist(classifier.decision_function(X_train.sel(trial_type=trial_type)), **kw)
    ax[1].hist(classifier.decision_function(X_test.sel(trial_type=trial_type)), **kw)
ax[0].set_xlabel("LDA score")
ax[1].set_xlabel("LDA score")
ax[0].set_title("train")
ax[1].set_title("test")
ax[0].legend(ncol=1, loc="upper left")
ax[1].legend(ncol=1, loc="upper left");
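Because the feature vector is just the stacked (chromo, channel, reltime) array, the learned LDA weights can be mapped back onto those dimensions for inspection. A hedged sketch, assuming the cells above were run as is:

# copy the weight vector into an array with the same stacked feature index, then unstack
weights = X.isel(epoch=0, drop=True).copy(data=classifier.coef_[0]).unstack("features")
display(weights.sizes)  # chromo x channel x reltime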