Sleep stage classification from polysomnography (PSG) data

Note

This code is taken from the analysis code used in [3]. If you reuse this code, please consider citing this work.

This tutorial explains how to perform a toy polysomnography analysis that answers the following question:

Important

Given two subjects from the Sleep Physionet dataset [1] [2], namely Alice and Bob, how well can we predict the sleep stages of Bob from Alice’s data?

This problem is tackled as a supervised multiclass classification task: for each 30-second chunk of data, the aim is to predict which of 5 possible sleep stages it belongs to.
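Concretely, each 30 s chunk becomes one sample of the classification problem, with the chunk's signal as input and its sleep stage as target. The sketch below frames a continuous recording that way with plain NumPy. The 100 Hz sampling rate, the 2-channel random signal and the random labels are placeholders chosen for illustration, not the Sleep Physionet data.

```python
import numpy as np

sfreq = 100.                    # assumed sampling rate (Hz)
epoch_duration = 30.            # seconds per sleep-stage decision
n_channels, n_seconds = 2, 300  # hypothetical 5-minute, 2-channel recording

rng = np.random.RandomState(0)
recording = rng.randn(n_channels, int(n_seconds * sfreq))  # placeholder signal

# Split the time axis into non-overlapping 30 s windows, dropping any remainder
n_times = int(epoch_duration * sfreq)
n_epochs = recording.shape[1] // n_times
X = recording[:, :n_epochs * n_times].reshape(n_channels, n_epochs, n_times)
X = X.transpose(1, 0, 2)  # -> (n_epochs, n_channels, n_times)

# One placeholder label per epoch (1=W, 2=stage 1, 3=stage 2, 4=stage 3/4, 5=R)
y = rng.randint(1, 6, size=n_epochs)
print(X.shape, y.shape)  # (10, 2, 3000) (10,)
```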

# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Stanislas Chambon <stan.chambon@gmail.com>
#          Joan Massich <mailsik@gmail.com>
#
# License: BSD Style.

import numpy as np
import matplotlib.pyplot as plt

import mne
from mne.datasets.sleep_physionet.age import fetch_data
from mne.time_frequency import psd_welch

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

Load the data

Here we download the data from two subjects. The end goal is to obtain epochs and their associated ground truth.

MNE-Python provides mne.datasets.sleep_physionet.age.fetch_data() to conveniently download data from the Sleep Physionet dataset [1] [2]. Given a list of subjects and recordings, the fetcher downloads the data and provides, for each subject, a pair of files:

  • -PSG.edf containing the polysomnography: the raw data from the EEG helmet.

  • -Hypnogram.edf containing the annotations recorded by an expert.

By combining these two in an mne.io.Raw object, we can then extract events based on the descriptions of the annotations and use them to obtain the epochs.

Read the PSG data and Hypnograms to create a raw object

ALICE, BOB = 0, 1

[alice_files, bob_files] = fetch_data(subjects=[ALICE, BOB], recording=[1])

mapping = {'EOG horizontal': 'eog',
           'Resp oro-nasal': 'misc',
           'EMG submental': 'misc',
           'Temp rectal': 'misc',
           'Event marker': 'misc'}

raw_train = mne.io.read_raw_edf(alice_files[0])
annot_train = mne.read_annotations(alice_files[1])

raw_train.set_annotations(annot_train, emit_warning=False)
raw_train.set_channel_types(mapping)

# plot some data
raw_train.plot(duration=60, scalings='auto')

Out:

Using default location ~/mne_data for PHYSIONET_SLEEP...
Extracting EDF parameters from /home/circleci/mne_data/physionet-sleep-data/SC4001E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...

Extract 30s events from annotations

The Sleep Physionet dataset is annotated using 8 labels: Wake (W); Stage 1, Stage 2, Stage 3 and Stage 4, ranging from light to deep sleep; REM sleep (R), where REM stands for Rapid Eye Movement; movement (M); and Stage (?) for any unscored segment.

We will work with only 5 stages: Wake (W), Stage 1, Stage 2, Stage 3/4, and REM sleep (R). To do so, we use the event_id parameter of mne.events_from_annotations() to select the events we are interested in and to associate an event identifier with each of them. Mapping both Stage 3 and Stage 4 to the same identifier merges them into a single Stage 3/4 class.
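The chunking and relabeling that mne.events_from_annotations() performs with chunk_duration=30. and an event_id mapping can be sketched in plain NumPy. This is a simplified illustration, not MNE's implementation: the helper annotations_to_events is hypothetical, it keeps only chunks that fit entirely inside an annotation, and the annotation list (including the 'Sleep stage ?' description for an unscored segment) is made up. The output uses the three-column (sample, previous value, event id) layout of MNE events arrays, with the middle column set to 0.

```python
import numpy as np


def annotations_to_events(onsets, durations, descriptions, event_id,
                          sfreq, chunk_duration=30.):
    """Hypothetical helper: cut annotations into fixed-length labeled chunks.

    Only chunks lying entirely inside their annotation are kept, and
    descriptions missing from event_id are skipped.
    """
    events = []
    for onset, duration, desc in zip(onsets, durations, descriptions):
        if desc not in event_id:
            continue  # e.g. movement or unscored segments
        n_chunks = int(duration // chunk_duration)
        for k in range(n_chunks):
            sample = int(round((onset + k * chunk_duration) * sfreq))
            events.append([sample, 0, event_id[desc]])
    return np.array(events, dtype=int).reshape(-1, 3)


# Same label mapping as in the tutorial: stages 3 and 4 share id 4
annotation_desc_2_event_id = {'Sleep stage W': 1, 'Sleep stage 1': 2,
                              'Sleep stage 2': 3, 'Sleep stage 3': 4,
                              'Sleep stage 4': 4, 'Sleep stage R': 5}

# Made-up hypnogram fragment (onsets and durations in seconds)
onsets = [0., 60., 120., 150., 210.]
durations = [60., 60., 30., 60., 45.]
descriptions = ['Sleep stage W', 'Sleep stage 3', 'Sleep stage ?',
                'Sleep stage 4', 'Sleep stage R']

events = annotations_to_events(onsets, durations, descriptions,
                               annotation_desc_2_event_id, sfreq=100.)
print(events)
```

The unscored segment is dropped, the last 15 s of the REM annotation are too short to form a chunk, and Stage 3 and Stage 4 chunks both end up with id 4.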

annotation_desc_2_event_id = {'Sleep stage W': 1,
                              'Sleep stage 1': 2,
                              'Sleep stage 2': 3,
                              'Sleep stage 3': 4,
                              'Sleep stage 4': 4,
                              'Sleep stage R': 5}

events_train, _ = mne.events_from_annotations(
    raw_train, event_id=annotation_desc_2_event_id, chunk_duration=30.)

# create a new event_id that unifies stages 3 and 4
event_id = {'Sleep stage W': 1,
            'Sleep stage 1': 2,
            'Sleep stage 2': 3,
            'Sleep stage 3/4': 4,
            'Sleep stage R': 5}

# plot events
mne.viz.plot_events(events_train, event_id=event_id,
                    sfreq=raw_train.info['sfreq'])

# keep the color-code for further plotting
stage_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

Out:

Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']

Create Epochs from the data based on the events found in the annotations

tmax = 30. - 1. / raw_train.info['sfreq']  # tmax is included

epochs_train = mne.Epochs(raw=raw_train, events=events_train,
                          event_id=event_id, tmin=0., tmax=tmax, baseline=None)

print(epochs_train)

Out:

2650 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
<Epochs  |   2650 events (good & bad), 0 - 29.99 sec, baseline off, ~20 kB, data not loaded,
 'Sleep stage 1': 58
 'Sleep stage 2': 250
 'Sleep stage 3/4': 220
 'Sleep stage R': 125
 'Sleep stage W': 1997>
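Because the sample at tmax belongs to the epoch, one sample period is subtracted from 30 s so that each epoch holds exactly 30 s of data. The arithmetic is sketched below with a hypothetical helper, n_samples_inclusive, assuming a 100 Hz sampling rate, which is consistent with the 3000 time points per epoch reported in the logs further down.

```python
def n_samples_inclusive(tmin, tmax, sfreq):
    """Hypothetical helper: number of samples in [tmin, tmax], both ends kept."""
    # round() absorbs floating-point error in (tmax - tmin) * sfreq
    return int(round((tmax - tmin) * sfreq)) + 1


sfreq = 100.  # assumed EEG sampling rate
print(n_samples_inclusive(0., 30., sfreq))               # 3001: one sample too many
print(n_samples_inclusive(0., 30. - 1. / sfreq, sfreq))  # 3000: exactly 30 s
```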

Applying the same steps to the test data from Bob

raw_test = mne.io.read_raw_edf(bob_files[0])
annot_test = mne.read_annotations(bob_files[1])
raw_test.set_annotations(annot_test, emit_warning=False)
raw_test.set_channel_types(mapping)
events_test, _ = mne.events_from_annotations(
    raw_test, event_id=annotation_desc_2_event_id, chunk_duration=30.)
epochs_test = mne.Epochs(raw=raw_test, events=events_test, event_id=event_id,
                         tmin=0., tmax=tmax, baseline=None)

print(epochs_test)

Out:

Extracting EDF parameters from /home/circleci/mne_data/physionet-sleep-data/SC4011E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
2802 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
<Epochs  |   2802 events (good & bad), 0 - 29.99 sec, baseline off, ~20 kB, data not loaded,
 'Sleep stage 1': 109
 'Sleep stage 2': 562
 'Sleep stage 3/4': 105
 'Sleep stage R': 170
 'Sleep stage W': 1856>

Feature Engineering

Observing the power spectral density (PSD) plot of the epochs grouped by sleep stage, we can see that different sleep stages have different signatures. These signatures are similar in Alice’s and Bob’s data.

In the rest of this section, we will create EEG features based on relative power in specific frequency bands to capture these differences between sleep stages.
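Written out in our own notation (not the original authors’), the features computed by eeg_power_band are as follows. Let $S_{e,c}(f)$ be the Welch PSD of epoch $e$ and EEG channel $c$ on the frequency grid $\mathcal{F}$ between 0.5 and 30 Hz, and let band $b$ span $[f_b^{\min}, f_b^{\max})$. Note that each feature is the mean, not the sum, of the normalized PSD over the frequency bins of a band, which gives 5 features per EEG channel.

```latex
\tilde{S}_{e,c}(f) = \frac{S_{e,c}(f)}{\sum_{f' \in \mathcal{F}} S_{e,c}(f')},
\qquad
\mathcal{F}_b = \{ f \in \mathcal{F} : f_b^{\min} \le f < f_b^{\max} \},
\qquad
x_{e,c,b} = \frac{1}{|\mathcal{F}_b|} \sum_{f \in \mathcal{F}_b} \tilde{S}_{e,c}(f),
\quad b \in \{\delta, \theta, \alpha, \sigma, \beta\}
```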

# visualize Alice vs. Bob PSD by sleep stage.
fig, (ax1, ax2) = plt.subplots(ncols=2)

# iterate over the subjects
stages = sorted(event_id.keys())
for ax, title, epochs in zip([ax1, ax2],
                             ['Alice', 'Bob'],
                             [epochs_train, epochs_test]):

    for stage, color in zip(stages, stage_colors):
        epochs[stage].plot_psd(area_mode=None, color=color, ax=ax,
                               fmin=0.1, fmax=20.)
    ax.set(title=title, xlabel='Frequency (Hz)')
ax1.set(ylabel='uV^2/Hz (dB)')
ax2.legend(stages)

Out:

Loading data for 58 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 250 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 220 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 125 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 1997 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 109 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 562 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 105 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 170 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Loading data for 1856 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows

Design a scikit-learn transformer from a Python function

We will now create a function to extract EEG features based on relative power in specific frequency bands to be able to predict sleep stages from EEG signals.

def eeg_power_band(epochs):
    """EEG relative power band feature extraction.

    This function takes an ``mne.Epochs`` object and creates EEG features based
    on relative power in specific frequency bands that are compatible with
    scikit-learn.

    Parameters
    ----------
    epochs : Epochs
        The data.

    Returns
    -------
    X : numpy array of shape [n_samples, 5 * n_channels]
        Relative power in 5 frequency bands for each EEG channel.
    """
    # specific frequency bands
    FREQ_BANDS = {"delta": [0.5, 4.5],
                  "theta": [4.5, 8.5],
                  "alpha": [8.5, 11.5],
                  "sigma": [11.5, 15.5],
                  "beta": [15.5, 30]}

    psds, freqs = psd_welch(epochs, picks='eeg', fmin=0.5, fmax=30.)
    # Normalize each PSD so that it sums to 1 (relative power)
    psds /= np.sum(psds, axis=-1, keepdims=True)

    X = []
    for fmin, fmax in FREQ_BANDS.values():
        psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=1)
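The same relative band power computation can be checked on synthetic data without MNE. The sketch below swaps in scipy.signal.welch as a stand-in for psd_welch; the 256-sample segments mirror the 2.56 s effective window reported in the logs at an assumed 100 Hz, but SciPy's default window and overlap may differ from MNE's, so the values will not match exactly. The helper relative_band_power and the test signals are hypothetical.

```python
import numpy as np
from scipy.signal import welch

FREQ_BANDS = {"delta": [0.5, 4.5],
              "theta": [4.5, 8.5],
              "alpha": [8.5, 11.5],
              "sigma": [11.5, 15.5],
              "beta": [15.5, 30]}


def relative_band_power(data, sfreq):
    """Band features from an array of shape (n_epochs, n_channels, n_times)."""
    # Welch PSD along the time axis with 256-sample segments
    freqs, psds = welch(data, fs=sfreq, nperseg=256, axis=-1)
    keep = (freqs >= 0.5) & (freqs <= 30.)
    freqs, psds = freqs[keep], psds[..., keep]
    # Normalize each PSD so that it sums to 1 over 0.5-30 Hz
    psds = psds / psds.sum(axis=-1, keepdims=True)
    # Mean normalized PSD per band, each of shape (n_epochs, n_channels)
    feats = [psds[..., (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
             for fmin, fmax in FREQ_BANDS.values()]
    return np.concatenate(feats, axis=1)  # (n_epochs, 5 * n_channels)


# Channel 0 is a 2 Hz (delta) sine, channel 1 a 10 Hz (alpha) sine
sfreq = 100.
t = np.arange(int(30 * sfreq)) / sfreq
epoch = np.stack([np.sin(2 * np.pi * 2. * t), np.sin(2 * np.pi * 10. * t)])
data = np.stack([epoch] * 3)  # 3 identical 30 s epochs
X = relative_band_power(data, sfreq)
print(X.shape)  # (3, 10)
```

As in eeg_power_band, columns are ordered band by band, so X[:, 0::2] holds the five band features of channel 0 and X[:, 1::2] those of channel 1; the largest value is delta for channel 0 and alpha for channel 1.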

Multiclass classification workflow using scikit-learn

To answer the question of how well we can predict the sleep stages of Bob from Alice’s data, while avoiding as much boilerplate code as possible, we will take advantage of two key features of scikit-learn: Pipeline and FunctionTransformer.

A scikit-learn Pipeline chains a sequence of transforms and a final estimator into a single estimator, while FunctionTransformer converts a Python function into an estimator-compatible object. In this manner, thanks to the eeg_power_band function we just created, we can build a scikit-learn estimator that takes mne.Epochs as input.
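The key point is that FunctionTransformer with validate=False hands its input to the wrapped function unchanged, so the pipeline can accept an object that is not an array. The sketch below shows this with a hypothetical ToyEpochs container standing in for mne.Epochs and a hypothetical mean_abs_amplitude feature; the data are random, with class 1 given a larger amplitude so the feature separates the classes.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer


class ToyEpochs:
    """Hypothetical stand-in for mne.Epochs: only wraps a 3D data array."""

    def __init__(self, data):
        self.data = data  # shape (n_epochs, n_channels, n_times)


def mean_abs_amplitude(epochs):
    """Hypothetical feature: mean absolute amplitude per channel."""
    return np.abs(epochs.data).mean(axis=-1)  # (n_epochs, n_channels)


rng = np.random.RandomState(0)
y = np.repeat([0, 1], 20)
# Class 1 epochs have three times the amplitude of class 0 epochs
data = rng.randn(40, 2, 100) * (1 + 2 * y)[:, None, None]
toy_epochs = ToyEpochs(data)

# validate=False skips array validation, so ToyEpochs reaches the function as-is
pipe = make_pipeline(FunctionTransformer(mean_abs_amplitude, validate=False),
                     RandomForestClassifier(n_estimators=10, random_state=0))
pipe.fit(toy_epochs, y)
print(pipe.score(toy_epochs, y))  # training accuracy
```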

pipe = make_pipeline(FunctionTransformer(eeg_power_band, validate=False),
                     RandomForestClassifier(n_estimators=100, random_state=42))

# Train
y_train = epochs_train.events[:, 2]
pipe.fit(epochs_train, y_train)

# Test
y_pred = pipe.predict(epochs_test)

# Assess the results
y_test = epochs_test.events[:, 2]
acc = accuracy_score(y_test, y_pred)

print("Accuracy score: {}".format(acc))

Out:

Loading data for 2650 events and 3000 original time points ...
0 bad epochs dropped
Effective window size : 2.560 (s)
Loading data for 2802 events and 3000 original time points ...
0 bad epochs dropped
Effective window size : 2.560 (s)
Accuracy score: 0.8468950749464668

In short, yes: a classifier trained on Alice’s data predicts Bob’s sleep stages with an accuracy of about 0.85.
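Because Wake dominates Bob’s recording, accuracy alone is easy to overrate. A useful reference point is a trivial model that always predicts the most frequent stage; the snippet computes its accuracy from the epoch counts printed by print(epochs_test) earlier.

```python
# Per-stage epoch counts for Bob, copied from the print(epochs_test) output
bob_counts = {'Sleep stage W': 1856, 'Sleep stage 1': 109, 'Sleep stage 2': 562,
              'Sleep stage 3/4': 105, 'Sleep stage R': 170}
n_total = sum(bob_counts.values())

# Accuracy of always predicting the most frequent stage (Wake)
majority_baseline = max(bob_counts.values()) / n_total
print(n_total, round(majority_baseline, 3))  # 2802 0.662
```

The pipeline’s 0.85 is therefore well above the 0.66 obtained by always answering Wake, but the margin is smaller than the raw accuracy suggests.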

Further analysis of the data

We can check the confusion matrix or the classification report.

print(confusion_matrix(y_test, y_pred))

Out:

[[1855    0    0    1    0]
 [  95    2    4    3    5]
 [ 121   16  383   30   12]
 [   0    0    4  101    0]
 [  89   30   19    0   32]]
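In scikit-learn’s confusion_matrix, row i counts the epochs whose true label is the i-th stage and column j those predicted as the j-th stage, with labels sorted (here 1 to 5: W, 1, 2, 3/4, R). Per-stage recall and precision therefore follow from row and column sums. The snippet recomputes them from the matrix above; it also makes explicit that most Stage 1 epochs (95 of 109) and about half of the REM epochs (89 of 170) are predicted as Wake.

```python
import numpy as np

# Confusion matrix printed above: rows = true stage, columns = predicted stage
cm = np.array([[1855,   0,   0,   1,   0],
               [  95,   2,   4,   3,   5],
               [ 121,  16, 383,  30,  12],
               [   0,   0,   4, 101,   0],
               [  89,  30,  19,   0,  32]])

correct = np.diag(cm)
recall = correct / cm.sum(axis=1)     # share of each true stage that is recovered
precision = correct / cm.sum(axis=0)  # share of each predicted stage that is right
accuracy = correct.sum() / cm.sum()

for name, p, r in zip(['W', '1', '2', '3/4', 'R'], precision, recall):
    print(f"{name:>4}: precision={p:.2f} recall={r:.2f}")
print(f"accuracy={accuracy:.4f}")
```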
print(classification_report(y_test, y_pred, target_names=event_id.keys()))

Out:

                 precision    recall  f1-score   support

  Sleep stage W       0.86      1.00      0.92      1856
  Sleep stage 1       0.04      0.02      0.03       109
  Sleep stage 2       0.93      0.68      0.79       562
Sleep stage 3/4       0.75      0.96      0.84       105
  Sleep stage R       0.65      0.19      0.29       170

       accuracy                           0.85      2802
      macro avg       0.65      0.57      0.57      2802
   weighted avg       0.83      0.85      0.82      2802
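The macro average (0.57 F1) is far below the accuracy because it gives every stage the same weight, so the poorly predicted Stage 1 and REM classes pull it down, whereas the weighted average follows the number of epochs per stage (the support). The sketch recomputes the F1 column and both averages from the confusion matrix printed above, using F1 = 2 TP / (2 TP + FP + FN), where the denominator equals the row sum plus the column sum.

```python
import numpy as np

# Confusion matrix printed above: rows = true stage, columns = predicted stage
cm = np.array([[1855,   0,   0,   1,   0],
               [  95,   2,   4,   3,   5],
               [ 121,  16, 383,  30,  12],
               [   0,   0,   4, 101,   0],
               [  89,  30,  19,   0,  32]])

tp = np.diag(cm)
support = cm.sum(axis=1)  # number of true epochs per stage
# 2*TP + FP + FN = (TP + FN) + (TP + FP) = row sum + column sum
f1 = 2 * tp / (support + cm.sum(axis=0))

macro_f1 = f1.mean()                           # each stage counts equally
weighted_f1 = np.average(f1, weights=support)  # stages weighted by support
print(np.round(f1, 2), round(macro_f1, 2), round(weighted_f1, 2))
```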

Exercise

Fetch 50 subjects from the Physionet database and run a 5-fold cross-validation, leaving out a different set of 10 subjects for testing in each fold.
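One possible starting point for the exercise is scikit-learn’s GroupKFold, which keeps all epochs of a subject on the same side of each split. The sketch below only checks the splitting logic: the features and labels are random placeholders, and the sizes (20 epochs per subject, 10 features) are assumptions. With equal-sized subjects, each of the 5 folds holds out exactly 10 subjects; with real recordings of different lengths, GroupKFold balances the number of epochs rather than subjects, so the count per fold may vary slightly.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical sizes: 50 subjects with 20 epochs each, 10 features per epoch
n_subjects, n_epochs_per_subject, n_features = 50, 20, 10
rng = np.random.RandomState(42)
X = rng.randn(n_subjects * n_epochs_per_subject, n_features)  # placeholder features
y = rng.randint(1, 6, size=len(X))                            # placeholder labels
groups = np.repeat(np.arange(n_subjects), n_epochs_per_subject)  # subject per epoch

test_subjects_per_fold = []
for train_idx, test_idx in GroupKFold(n_splits=5).split(X, y, groups):
    train_subjects = set(groups[train_idx])
    test_subjects = set(groups[test_idx])
    assert not train_subjects & test_subjects  # no subject in both sets
    test_subjects_per_fold.append(test_subjects)
    print(f"train: {len(train_subjects)} subjects, "
          f"test: {len(test_subjects)} subjects")
```

The same groups array can be passed to cross_val_score(pipe, X, y, groups=groups, cv=GroupKFold(n_splits=5)) to score a complete pipeline.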

References

[1] Kemp B, Zwinderman AH, Tuk B, Kamphuisen HAC, Oberyé JJL. Analysis of a sleep-dependent neuronal feedback loop: the slow-wave microcontinuity of the EEG. IEEE-BME 47(9):1185-1194 (2000).

[2] Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220 (2000).

[3] Chambon S, Galtier M, Arnal P, Wainrib G, Gramfort A. A Deep Learning Architecture for Temporal Sleep Stage Classification Using Multivariate and Multimodal Time Series. IEEE Trans. on Neural Systems and Rehabilitation Engineering 26:758-769 (2018).

Total running time of the script: ( 0 minutes 33.297 seconds)

Estimated memory usage: 667 MB
