Basic single trial fNIRS finger tapping classification

This notebook sketches the analysis of a finger-tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left from right finger tapping on single trials.

[1]:
import cedalion
import cedalion.io  # read_snirf is called as cedalion.io.read_snirf below
import cedalion.nirs
import cedalion.xrutils as xrutils
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import numpy as np
import xarray as xr
import pint
import matplotlib.pyplot as p
import scipy.signal
import os.path

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.metrics import accuracy_score

xr.set_options(display_max_rows=3, display_values_threshold=50)
np.set_printoptions(precision=4)

Loading raw CW-NIRS data from a SNIRF file

This notebook uses a finger-tapping dataset in BIDS layout provided by Rob Luke. It can be downloaded via cedalion.datasets.

xarray provides another container: xr.Dataset. A Dataset is a collection of xr.DataArray objects that share coordinate axes. Here, one Dataset per subject groups the arrays derived from that subject's recording.
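To make the idea of shared coordinate axes concrete, here is a minimal sketch with made-up values. The channel labels, array sizes, and the `note` attribute are illustrative assumptions and are not taken from the dataset.

```python
import numpy as np
import xarray as xr

# Toy time axis and channel labels (made-up values, not from the dataset)
time = np.arange(5) * 0.128
channels = ["S1D1", "S1D2"]

amp = xr.DataArray(
    np.random.default_rng(0).uniform(0.5, 1.5, size=(2, 5)),
    dims=["channel", "time"],
    coords={"channel": channels, "time": time},
)
# A derived array that lives on the same axes
od = -np.log(amp / amp.mean("time"))

ds = xr.Dataset({"amp": amp, "od": od}, attrs={"note": "toy example"})

# Selecting on a shared coordinate applies to every variable at once
first_channel = ds.sel(channel="S1D1")
print(first_channel["amp"].shape, first_channel["od"].shape)
```

Because both variables carry the same `channel` and `time` coordinates, a single `sel` call slices them consistently.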

[2]:
fnames = get_multisubject_fingertapping_snirf_paths()
subjects  = [f"sub-{i:02d}" for i in [1,2,3,4,5]]

# store data of different subjects in a dictionary
data = {}
for subject,fname in zip(subjects, fnames):

    elements = cedalion.io.read_snirf(fname)

    amp = elements[0].data[0]
    stim = elements[0].stim # pandas Dataframe
    geo3d = elements[0].geo3d

    # cedalion registers an accessor (attribute .cd ) on pandas DataFrames
    stim.cd.rename_events( {
        "1.0" : "control",
        "2.0" : "Tapping/Left",
        "3.0" : "Tapping/Right"
    })

    dpf = xr.DataArray([6, 6], dims="wavelength", coords={"wavelength" : amp.wavelength})

    data[subject] = xr.Dataset(
        data_vars = {
            "amp" : amp,
            "od"  : - np.log( amp / amp.mean("time") ),
            "geo" : geo3d,
            "conc": cedalion.nirs.beer_lambert(amp, geo3d, dpf)
        },
        attrs={"stim" : stim}, # store stimulus data in attrs
        coords={"subject" : subject} # add the subject label as a coordinate
    )
Downloading file 'multisubject-fingertapping.zip' from 'https://doc.ml.tu-berlin.de/cedalion/datasets/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion'.
Unzipping contents of '/home/runner/.cache/cedalion/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion/multisubject-fingertapping.zip.unzip'
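The "od" entry above is computed as the negative logarithm of the amplitude relative to its temporal mean. A plain NumPy sketch of that formula on synthetic amplitudes (the array sizes and noise level are arbitrary assumptions) could look like this:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy amplitudes: 2 channels x 1000 samples, strictly positive
amp = 1.0 + 0.05 * rng.standard_normal((2, 1000))

# Optical density change relative to the temporal mean of each channel
channel_mean = amp.mean(axis=-1, keepdims=True)
od = -np.log(amp / channel_mean)

# An amplitude below the channel mean gives a positive OD,
# an amplitude above the mean gives a negative OD.
print(od.shape)
```

The result is dimensionless and fluctuates around zero, which is why the `od` variable in the Dataset printout below carries no unit.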

Illustrate the dataset of one subject

[3]:
display(data["sub-01"])
<xarray.Dataset> Size: 32MB
Dimensions:     (time: 23239, channel: 28, wavelength: 2, label: 55, pos: 3,
                 chromo: 2)
Coordinates: (3/10)
  * time        (time) float64 186kB 0.0 0.128 0.256 ... 2.974e+03 2.974e+03
    samples     (time) int64 186kB 0 1 2 3 4 5 ... 23234 23235 23236 23237 23238
    ...          ...
    subject     <U6 24B 'sub-01'
Dimensions without coordinates: pos
Data variables: (3/4)
    amp         (channel, wavelength, time) float64 10MB [] 0.09137 ... 1.276
    od          (channel, wavelength, time) float64 10MB [] 0.04042 ... -0.01317
    ...          ...
    conc        (chromo, channel, time) float64 10MB [µM] 0.1336 ... -1.076
Attributes:
    stim:            onset  duration  value     trial_type\n0     61.824     ...

Frequency filtering and splitting into epochs

[4]:
for subject, ds in data.items():
    # cedalion registers the accessor .cd on DataArrays
    # to provide common functionality like frequency filters...
    ds["conc_freqfilt"] = ds["conc"].cd.freq_filter(fmin=0.02, fmax=0.5, butter_order=4)

    # ... or epoch splitting
    ds["cfepochs"] = ds["conc_freqfilt"].cd.to_epochs(
        ds.attrs["stim"], # stimulus dataframe
        ["Tapping/Left", "Tapping/Right"],  # select events
        before=5, # seconds before stimulus
        after=20  # seconds after stimulus
    )
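The internals of `freq_filter` are not shown in this notebook. As a rough illustration of what a 4th-order Butterworth band-pass between 0.02 Hz and 0.5 Hz does, the following sketch uses SciPy directly on a synthetic signal. The zero-phase (forward-backward) application and the test frequencies are assumptions made for the example, not statements about cedalion.

```python
import numpy as np
import scipy.signal

fs = 1 / 0.128  # sampling rate in Hz, from the 0.128 s time step shown above
t = np.arange(0, 300, 1 / fs)

# Toy signal: slow drift (0.005 Hz) + task-like component (0.1 Hz)
# + fast component (1.2 Hz)
slow = np.sin(2 * np.pi * 0.005 * t)
task = 0.5 * np.sin(2 * np.pi * 0.1 * t)
fast = 0.5 * np.sin(2 * np.pi * 1.2 * t)
signal = slow + task + fast

# Butterworth band-pass in second-order sections, applied forward and
# backward so that the filtered signal has no phase shift
sos = scipy.signal.butter(4, [0.02, 0.5], btype="bandpass", fs=fs, output="sos")
filtered = scipy.signal.sosfiltfilt(sos, signal)

# Compare against the in-band component, away from the edges
mid = slice(len(t) // 4, 3 * len(t) // 4)
corr = np.corrcoef(filtered[mid], task[mid])[0, 1]
print(f"correlation with the 0.1 Hz component: {corr:.3f}")
```

The lower cutoff removes slow drifts and the upper cutoff removes fast components such as cardiac pulsation, while the frequency range of the hemodynamic response passes through.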

Plot frequency filtered data

Illustrate the effect of the bandpass filter for a single subject and channel.

[5]:
ds = data["sub-01"]
channel = "S5D7"

f,ax= p.subplots(2,1, figsize=(12,4), sharex=True)
ax[0].plot(ds.time, ds.conc.sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[0].plot(ds.time, ds.conc.sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[1].plot(ds.time, ds.conc_freqfilt.sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[1].plot(ds.time, ds.conc_freqfilt.sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[0].set_xlim(1000,1200)
ax[1].set_xlabel("time / s")
# raw strings avoid invalid escape sequences such as "\D" and "\m"
ax[0].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[1].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[0].legend(loc="upper left"); ax[1].legend(loc="upper left");
../_images/examples_finger_tapping_lda_classification_10_0.png
[6]:
display(data["sub-01"]["cfepochs"])
<xarray.DataArray 'cfepochs' (epoch: 60, chromo: 2, channel: 28, reltime: 196)> Size: 5MB
<Quantity([[[[-1.9479e-02 -1.9950e-02 -2.0945e-02 ... -1.9397e-01 -2.1930e-01
    -2.4552e-01]
   [-1.0442e-02 -1.1375e-02 -1.2845e-02 ... -2.5376e-01 -2.5998e-01
    -2.6454e-01]
   [-1.2500e-02 -8.8897e-03 -5.8428e-03 ... -2.1148e-01 -2.2557e-01
    -2.3975e-01]
   ...
   [ 9.5644e-02  1.0075e-01  1.0499e-01 ... -3.2907e-01 -3.4702e-01
    -3.6405e-01]
   [ 2.0733e-02  2.2404e-02  2.3932e-02 ... -2.9462e-01 -2.9954e-01
    -3.0371e-01]
   [ 1.4324e-03  2.0414e-03  4.1026e-03 ... -3.0974e-01 -3.2710e-01
    -3.4287e-01]]

  [[ 1.4641e-02  6.4301e-03 -1.6854e-03 ... -6.5363e-02 -6.2839e-02
    -5.7220e-02]
   [ 2.0950e-03  5.5313e-03  8.9982e-03 ... -4.8858e-02 -5.4990e-02
    -6.0070e-02]
   [ 4.5053e-02  4.2005e-02  3.9300e-02 ... -4.3936e-02 -4.3039e-02
    -4.0861e-02]
...
   [ 2.8454e-01  2.9324e-01  3.0016e-01 ...  1.0617e-01  1.1339e-01
     1.2119e-01]
   [ 3.4306e-01  3.9592e-01  4.4809e-01 ... -1.8578e-01 -1.7872e-01
    -1.7142e-01]
   [ 6.1801e-01  6.1059e-01  6.0291e-01 ...  1.8868e-01  1.9003e-01
     1.9076e-01]]

  [[ 6.5167e-02  5.7419e-02  4.9134e-02 ...  3.1521e-02  3.2053e-02
     3.1760e-02]
   [ 3.6845e-02  3.4700e-02  3.1450e-02 ... -3.0447e-02 -2.7677e-02
    -2.3868e-02]
   [ 4.8888e-02  3.9980e-02  3.2743e-02 ...  3.0387e-02  3.4824e-02
     3.9630e-02]
   ...
   [ 5.8174e-02  5.3107e-02  4.7739e-02 ... -8.1363e-03 -6.2960e-03
    -5.3040e-03]
   [ 1.5154e-01  1.6202e-01  1.7299e-01 ...  2.6186e-02  3.2992e-02
     3.8859e-02]
   [ 1.5938e-01  1.5348e-01  1.4687e-01 ...  3.8222e-02  4.2833e-02
     4.6774e-02]]]], 'micromolar')>
Coordinates: (3/7)
  * channel     (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
    source      (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    ...          ...
    trial_type  (epoch) object 480B 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
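The reltime length of 196 follows from the window of 5 s before and 20 s after each onset at a sample spacing of 0.128 s. The sketch below cuts such windows from a toy time series with NumPy. The onsets and the rounding convention are hypothetical choices for illustration; how `to_epochs` handles these details is not shown here.

```python
import numpy as np

dt = 0.128                                # sample spacing in seconds
time = np.arange(0, 200, dt)
signal = np.sin(2 * np.pi * 0.1 * time)   # toy single-channel time series
onsets = np.array([30.0, 80.0, 130.0])    # hypothetical stimulus onsets in s

before, after = 5.0, 20.0
n_before = int(round(before / dt))        # samples before the onset
n_after = int(round(after / dt))          # samples after the onset

# Time axis relative to the onset, shared by all epochs
reltime = np.arange(-n_before, n_after + 1) * dt

epochs = []
for onset in onsets:
    i0 = np.searchsorted(time, onset)     # first sample at or after the onset
    epochs.append(signal[i0 - n_before : i0 + n_after + 1])
epochs = np.stack(epochs)                 # shape (epoch, reltime)

print(epochs.shape, int((reltime >= 0).sum()))
```

With these conventions each epoch has 196 samples, of which 157 lie at or after the onset.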
[7]:
all_epochs = xr.concat([ds["cfepochs"] for ds in data.values()], dim="epoch")
all_epochs
[7]:
<xarray.DataArray 'cfepochs' (epoch: 300, chromo: 2, channel: 28, reltime: 196)> Size: 26MB
<Quantity([[[[-1.9479e-02 -1.9950e-02 -2.0945e-02 ... -1.9397e-01 -2.1930e-01
    -2.4552e-01]
   [-1.0442e-02 -1.1375e-02 -1.2845e-02 ... -2.5376e-01 -2.5998e-01
    -2.6454e-01]
   [-1.2500e-02 -8.8897e-03 -5.8428e-03 ... -2.1148e-01 -2.2557e-01
    -2.3975e-01]
   ...
   [ 9.5644e-02  1.0075e-01  1.0499e-01 ... -3.2907e-01 -3.4702e-01
    -3.6405e-01]
   [ 2.0733e-02  2.2404e-02  2.3932e-02 ... -2.9462e-01 -2.9954e-01
    -3.0371e-01]
   [ 1.4324e-03  2.0414e-03  4.1026e-03 ... -3.0974e-01 -3.2710e-01
    -3.4287e-01]]

  [[ 1.4641e-02  6.4301e-03 -1.6854e-03 ... -6.5363e-02 -6.2839e-02
    -5.7220e-02]
   [ 2.0950e-03  5.5313e-03  8.9982e-03 ... -4.8858e-02 -5.4990e-02
    -6.0070e-02]
   [ 4.5053e-02  4.2005e-02  3.9300e-02 ... -4.3936e-02 -4.3039e-02
    -4.0861e-02]
...
   [-1.8642e-01 -1.8383e-01 -1.8031e-01 ...  9.3582e-03  1.1076e-02
     1.3423e-02]
   [-2.8335e-01 -2.7513e-01 -2.6600e-01 ... -1.3275e-02 -4.6116e-03
     2.3698e-03]
   [-3.7102e-01 -3.6417e-01 -3.5820e-01 ...  1.8997e-01  1.9610e-01
     2.0283e-01]]

  [[ 6.9413e-02  6.7628e-02  6.5086e-02 ...  2.3354e-02  8.2494e-03
    -6.5091e-03]
   [ 9.3748e-02  7.1599e-02  4.8922e-02 ...  6.0442e-03  3.1716e-02
     4.8579e-02]
   [ 1.3260e-01  1.2931e-01  1.2489e-01 ...  1.8629e-01  1.6657e-01
     1.4594e-01]
   ...
   [ 3.5046e-02  3.1738e-02  2.8594e-02 ...  9.8597e-02  1.0093e-01
     1.0222e-01]
   [ 8.7634e-02  8.3598e-02  8.0126e-02 ...  1.1028e-01  1.0263e-01
     9.5370e-02]
   [-3.3306e-02 -3.6382e-02 -3.7580e-02 ...  4.1337e-03  3.1121e-03
     5.3935e-04]]]], 'micromolar')>
Coordinates: (3/7)
  * channel     (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
    source      (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    ...          ...
    trial_type  (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch

Block Averages

[8]:
# calculate baseline
baseline = all_epochs.sel(reltime=(all_epochs.reltime < 0)).mean("reltime")
# subtract baseline
all_epochs_blcorrected = all_epochs - baseline

# group trials by trial_type and average over the epoch dimension within each group
blockaverage = all_epochs_blcorrected.groupby("trial_type").mean("epoch")
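The same two steps, baseline subtraction and label-wise averaging, can be written in plain NumPy. The sketch uses random toy data with a constant offset; the number of epochs and the offset are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
reltime = np.arange(-39, 157) * 0.128     # 196 samples, onset at index 39
labels = np.array(["Tapping/Left", "Tapping/Right"] * 3)   # toy trial labels
epochs = rng.standard_normal((6, reltime.size)) + 2.0      # toy data with offset

# Mean over the pre-stimulus samples, kept as a column for broadcasting
baseline = epochs[:, reltime < 0].mean(axis=1, keepdims=True)
corrected = epochs - baseline

# Average all epochs that share a label
blockaverage = {
    str(lab): corrected[labels == lab].mean(axis=0) for lab in np.unique(labels)
}
print({lab: avg.shape for lab, avg in blockaverage.items()})
```

After the correction, the pre-stimulus mean of every epoch is zero, so constant offsets between trials no longer enter the average.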

Plotting averaged epochs

[9]:
f,ax = p.subplots(5,6, figsize=(12,10))
ax = ax.flatten()
for i_ch, ch in enumerate(blockaverage.channel):
    for ls, trial_type in zip(["-", "--"], blockaverage.trial_type):
        ax[i_ch].plot(blockaverage.reltime, blockaverage.sel(chromo="HbO", trial_type=trial_type, channel=ch), "r", lw=2, ls=ls)
        ax[i_ch].plot(blockaverage.reltime, blockaverage.sel(chromo="HbR", trial_type=trial_type, channel=ch), "b", lw=2, ls=ls)
    ax[i_ch].grid(True)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-.3, .6)

p.tight_layout()
../_images/examples_finger_tapping_lda_classification_16_0.png

Training an LDA classifier with scikit-learn

[10]:
# start with the frequency-filtered, epoched and baseline-corrected concentration data
# discard the samples before the stimulus onset
epochs = all_epochs_blcorrected.sel(reltime=all_epochs_blcorrected.reltime >=0)
# strip units. sklearn would strip them anyway and issue a warning about it.
epochs = epochs.pint.dequantify()

# need to manually tell xarray to create an index for trial_type
epochs = epochs.set_xindex("trial_type")
[11]:
display(epochs)
<xarray.DataArray 'cfepochs' (epoch: 300, chromo: 2, channel: 28, reltime: 157)> Size: 21MB
array([[[[-3.6270e-02, -1.8065e-02,  2.0471e-03, ..., -6.8976e-02,
          -9.4304e-02, -1.2053e-01],
         [ 4.6475e-03,  1.1837e-02,  1.9097e-02, ..., -1.5186e-01,
          -1.5808e-01, -1.6264e-01],
         [ 2.4496e-03,  1.2294e-02,  2.3849e-02, ..., -1.3250e-01,
          -1.4658e-01, -1.6077e-01],
         ...,
         [-4.3450e-02, -3.0909e-02, -1.6201e-02, ..., -3.4983e-01,
          -3.6779e-01, -3.8482e-01],
         [ 2.2952e-02,  3.4524e-02,  4.6622e-02, ..., -2.6465e-01,
          -2.6958e-01, -2.7375e-01],
         [-4.2557e-03,  1.3953e-02,  3.3226e-02, ..., -2.5881e-01,
          -2.7617e-01, -2.9193e-01]],

        [[ 1.5914e-02,  1.3944e-02,  1.0925e-02, ..., -7.7108e-02,
          -7.4584e-02, -6.8965e-02],
         [-1.2916e-02, -9.4326e-03, -4.7285e-03, ..., -5.5005e-02,
          -6.1136e-02, -6.6217e-02],
         [-6.1391e-03, -2.1726e-03, -2.9567e-05, ..., -8.1372e-02,
          -8.0476e-02, -7.8297e-02],
...
         [ 1.5353e-01,  1.5630e-01,  1.5688e-01, ...,  6.7293e-02,
           6.9011e-02,  7.1357e-02],
         [ 2.4956e-01,  2.4903e-01,  2.4447e-01, ...,  7.8465e-02,
           8.7128e-02,  9.4109e-02],
         [ 4.8391e-01,  4.8248e-01,  4.7416e-01, ...,  3.1090e-01,
           3.1703e-01,  3.2376e-01]],

        [[-1.5826e-02, -1.2935e-02, -5.2803e-03, ...,  2.3303e-03,
          -1.2774e-02, -2.7533e-02],
         [-1.3622e-02, -1.8221e-02, -2.3395e-02, ...,  1.1320e-02,
           3.6992e-02,  5.3854e-02],
         [-3.5104e-02, -3.3747e-02, -3.9994e-02, ...,  7.3406e-02,
           5.3683e-02,  3.3055e-02],
         ...,
         [ 1.1524e-02,  1.4339e-02,  1.6261e-02, ...,  9.8163e-02,
           1.0049e-01,  1.0178e-01],
         [-5.6057e-02, -5.8393e-02, -5.9799e-02, ...,  6.0866e-02,
           5.3214e-02,  4.5954e-02],
         [ 3.0087e-02,  3.2592e-02,  3.6472e-02, ..., -5.2315e-03,
          -6.2531e-03, -8.8259e-03]]]])
Coordinates: (3/7)
  * channel     (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
    source      (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    ...          ...
  * trial_type  (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
Attributes:
    units:    micromolar
[12]:
X = epochs.stack(features=["chromo", "channel", "reltime"])
display(X)
<xarray.DataArray 'cfepochs' (epoch: 300, features: 8792)> Size: 21MB
array([[-0.0363, -0.0181,  0.002 , ..., -0.0126, -0.0093, -0.0072],
       [ 0.109 ,  0.0983,  0.0847, ..., -0.0343, -0.036 , -0.0378],
       [-0.0203, -0.0267, -0.0325, ..., -0.0187, -0.0155, -0.0115],
       ...,
       [ 0.029 ,  0.0078, -0.0164, ..., -0.0117, -0.0013,  0.0085],
       [-0.2281, -0.2284, -0.2273, ...,  0.0104,  0.0108,  0.0135],
       [ 0.2505,  0.2608,  0.2643, ..., -0.0052, -0.0063, -0.0088]])
Coordinates: (3/8)
    source      (features) object 70kB 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    detector    (features) object 70kB 'D1' 'D1' 'D1' 'D1' ... 'D16' 'D16' 'D16'
    ...          ...
  * reltime     (features) float64 70kB 0.0 0.128 0.256 ... 19.71 19.84 19.97
Dimensions without coordinates: epoch
Attributes:
    units:    micromolar
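The feature count of 8792 is the product of the stacked dimensions: 2 chromophores x 28 channels x 157 time samples. In NumPy terms the stacking corresponds to a row-major reshape, as this sketch with four toy epochs illustrates (the array content is arbitrary):

```python
import numpy as np

n_epoch, n_chromo, n_channel, n_reltime = 4, 2, 28, 157   # 4 toy epochs
epochs = np.arange(
    n_epoch * n_chromo * n_channel * n_reltime, dtype=float
).reshape(n_epoch, n_chromo, n_channel, n_reltime)

# Flatten everything except the epoch axis into one feature axis
X = epochs.reshape(n_epoch, -1)
print(X.shape)  # (4, 8792)

# Position of (chromo=1, channel=3, reltime sample 10) on the feature axis
idx = np.ravel_multi_index((1, 3, 10), (n_chromo, n_channel, n_reltime))
assert X[0, idx] == epochs[0, 1, 3, 10]
```

The last stacked dimension (reltime) varies fastest along the feature axis, which matches the reltime values 0.0, 0.128, 0.256, ... listed for the first features above.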
[13]:
y = xr.apply_ufunc(LabelEncoder().fit_transform, X.trial_type)
display(y)
<xarray.DataArray 'trial_type' (epoch: 300)> Size: 2kB
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Coordinates:
    subject     (epoch) <U6 7kB 'sub-01' 'sub-01' 'sub-01' ... 'sub-05' 'sub-05'
  * trial_type  (epoch) object 2kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
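`LabelEncoder` maps the sorted unique labels to consecutive integers, so "Tapping/Left" becomes 0 and "Tapping/Right" becomes 1. The short sketch below uses a toy label list and keeps a reference to the encoder, which makes it possible to translate predictions back to label strings later.

```python
from sklearn.preprocessing import LabelEncoder

trial_types = ["Tapping/Left", "Tapping/Right", "Tapping/Right", "Tapping/Left"]

encoder = LabelEncoder()
y_toy = encoder.fit_transform(trial_types)

print(y_toy)              # [0 1 1 0]
print(encoder.classes_)   # ['Tapping/Left' 'Tapping/Right']

# Map integer predictions back to the original labels
decoded = encoder.inverse_transform([1, 0])
print(decoded)
```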
[14]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3, stratify=y)
classifier = LinearDiscriminantAnalysis(n_components=1).fit(X_train, y_train)
y_pred = classifier.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
Accuracy: 0.8111111111111111
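The accuracy above comes from a single random split, so it changes from run to run. One possible way to get a more stable estimate is k-fold cross-validation with `cross_validate`, which is imported in the first cell. The sketch below runs on synthetic data, and the shrinkage setting is an assumption added for illustration; it is not part of the analysis above.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_validate

rng = np.random.default_rng(0)
n_per_class, n_features = 60, 20
# Synthetic two-class data: the class means differ by 1 in every feature
X_toy = np.vstack([
    rng.standard_normal((n_per_class, n_features)),
    rng.standard_normal((n_per_class, n_features)) + 1.0,
])
y_toy = np.repeat([0, 1], n_per_class)

# Stratified folds keep the class balance in every split
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
# Shrinkage regularizes the covariance estimate (an assumption of this sketch)
clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto")

result = cross_validate(clf, X_toy, y_toy, cv=cv, scoring="accuracy")
scores = result["test_score"]
print(f"accuracy: {scores.mean():.2f} +/- {scores.std():.2f}")
```

With 8792 features and only a few hundred epochs, a regularized covariance estimate may be worth trying on the real data as well, but that would need to be checked empirically.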
[15]:
f,ax = p.subplots(1,2, figsize=(12,3))
for trial_type, c in zip(["Tapping/Left", "Tapping/Right"], ["r", "g"]):
    kw  = dict(alpha=.5, fc=c, label=trial_type)
    ax[0].hist(classifier.decision_function(X_train.sel(trial_type=trial_type)),**kw)
    ax[1].hist(classifier.decision_function(X_test.sel(trial_type=trial_type)), **kw)

for a, title in zip(ax, ["train", "test"]):
    a.set_xlabel("LDA score")
    a.set_title(title)
    a.legend(ncol=1, loc="upper left")
../_images/examples_finger_tapping_lda_classification_23_0.png
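For a two-class problem, the LDA score returned by `decision_function` is a linear projection of the features, and the predicted class follows from its sign. The sketch below checks this on synthetic data; the data dimensions are arbitrary assumptions.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(1)
X_toy = np.vstack([
    rng.standard_normal((50, 5)),
    rng.standard_normal((50, 5)) + 1.5,
])
y_toy = np.repeat([0, 1], 50)

clf = LinearDiscriminantAnalysis().fit(X_toy, y_toy)
scores = clf.decision_function(X_toy)

# The score is a weighted sum of the features plus an offset
manual = X_toy @ clf.coef_.ravel() + clf.intercept_[0]
assert np.allclose(scores, manual)

# Positive scores are assigned to class 1, all others to class 0
assert np.array_equal(clf.predict(X_toy), (scores > 0).astype(int))
```

In the histograms above, the overlap of the two distributions around a score of zero therefore corresponds to the misclassified epochs.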