Last active
April 15, 2025 11:43
-
-
Save emadeldeen24/510e24701dc7a484fd72880c787f7d8c to your computer and use it in GitHub Desktop.
Extract sleep stages from MESA dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import warnings
import xml.etree.ElementTree as ET

import numpy as np
import torch
from mne.io import read_raw_edf

# Silence deprecation warnings and lower MNE's log level to WARNING.
warnings.filterwarnings("ignore", category=DeprecationWarning)
logging.getLogger('mne').setLevel(logging.WARNING)

# Length of one scoring epoch, in seconds.
EPOCH_SEC_SIZE = 30

# Annotation "EventConcept" text -> integer class label.
# Stage 3 and Stage 4 both map to 3, i.e. they are merged into one class.
ann2label = {
    "Wake|0": 0,
    "Stage 1 sleep|1": 1,
    "Stage 2 sleep|2": 2,
    "Stage 3 sleep|3": 3,
    "Stage 4 sleep|4": 3,
    "REM sleep|5": 4,
}

data_dir = "/x/mnt/hdd/emad_data/mesa/polysomnography/edfs"
ann_dir = "/x/mnt/hdd/emad_data/mesa/polysomnography/annotations-events-nsrr"
save_dir = "/x/mnt/hdd/emad_data/mesa_preprocessed"


def _list_files(directory, extension):
    """Return the sorted full paths of regular files in `directory` ending in `extension`.

    Only the top level of `directory` is scanned. A missing directory raises
    FileNotFoundError (the previous ``next(os.walk(...))`` raised a bare
    StopIteration instead).
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if os.path.splitext(name)[1] == extension
        and os.path.isfile(os.path.join(directory, name))
    )


edf_fnames = np.asarray(_list_files(data_dir, '.edf'))
# Filter on the file name only: ann_dir itself contains "nsrr", so testing the
# full path would keep every XML file and make the filter a no-op.
ann_fnames = np.asarray([
    path for path in _list_files(ann_dir, '.xml')
    if "nsrr" in os.path.basename(path)
])

# Module-level epoch -> label map; the per-subject loop resets it before each
# annotation file is parsed.
epoch_label_map = {}
# Parse one annotation XML file into sleep-stage events.
def get_labels(ann_fname, label_map=None):
    """Return the sleep-stage events of one annotation XML file.

    Args:
        ann_fname: Path (or file object) of the annotation XML; anything
            accepted by ``xml.etree.ElementTree.parse``.
        label_map: Mapping from ``EventConcept`` text (e.g. ``"Wake|0"``) to
            an integer class label. Defaults to the module-level ``ann2label``.
            Events whose concept is not in the mapping (respiratory events,
            arousals, ...) are ignored.

    Returns:
        A new dict mapping ``(start_sec, end_sec)`` -> class label, in file
        order. Events sharing the same (start, end) span overwrite earlier ones.
    """
    if label_map is None:
        label_map = ann2label

    # Build a fresh dict on every call so events never leak between files.
    events = {}
    root = ET.parse(ann_fname).getroot()
    for event in root.iter('ScoredEvent'):
        concept = event.findtext('EventConcept')
        # Filter before converting numbers, so a malformed non-stage event
        # cannot abort parsing of the whole file.
        if concept not in label_map:
            continue
        start = float(event.findtext('Start'))
        duration = float(event.findtext('Duration'))
        events[(start, start + duration)] = label_map[concept]
    return events
def _epochs_from_events(raw_ch, epoch_label_map, sampling_rate):
    """Cut labelled events into fixed-length epochs of EPOCH_SEC_SIZE seconds.

    Args:
        raw_ch: 2-D array indexed ``[sample, channel]`` (rows are time samples).
        epoch_label_map: ``{(start_sec, end_sec): label}`` event mapping.
        sampling_rate: Samples per second of ``raw_ch``.

    Returns:
        ``(x, y)``: ``x`` has shape ``(n_epochs, n_channels, epoch_samples)``
        and ``y`` has shape ``(n_epochs,)``. Both are empty when no complete
        epoch fits inside any event.
    """
    epoch_len = int(EPOCH_SEC_SIZE * sampling_rate)
    n_channels = raw_ch.shape[1]
    epochs = []
    labels = []
    for (epoch_start, epoch_end), label in epoch_label_map.items():
        # Convert event timestamps (seconds) to sample indices.
        start_sample = int(epoch_start * sampling_rate)
        end_sample = int(epoch_end * sampling_rate)
        segment = raw_ch[start_sample:end_sample]
        n_epochs = len(segment) // epoch_len
        if n_epochs == 0:
            # Shorter than one epoch, or entirely past the end of the
            # recording; splitting into zero sections would raise.
            continue
        # Drop any trailing partial epoch (e.g. an event truncated at the end
        # of the recording) so the segment divides evenly.
        segment = segment[:n_epochs * epoch_len]
        epochs.append(segment.reshape(n_epochs, epoch_len, n_channels))
        labels.extend([label] * n_epochs)
    if not epochs:
        return (np.empty((0, n_channels, epoch_len), dtype=raw_ch.dtype),
                np.empty(0, dtype=np.int64))
    x = np.concatenate(epochs).transpose(0, 2, 1)
    y = np.array(labels)
    return x, y


# Channels kept from each EDF, in this order.
select_ch = ['EKG', 'EOG-L', 'EOG-R', 'EMG', 'EEG1', 'EEG2', 'EEG3']
os.makedirs(save_dir, exist_ok=True)

# Pair recordings with annotations by subject id rather than by sorted
# position, so one missing file cannot misalign every following pair.
ann_by_subject = {
    os.path.basename(ann_fname).split("-")[-2]: ann_fname
    for ann_fname in ann_fnames
}

for edf_fname in edf_fnames:
    subject_id = os.path.basename(edf_fname).split("-")[-1].split(".")[0]
    ann_fname = ann_by_subject.get(subject_id)
    if ann_fname is None:
        print(f"No annotation file found for subject {subject_id}; skipping")
        continue

    raw = read_raw_edf(edf_fname, preload=True, stim_channel=None, verbose=None)
    try:
        sampling_rate = raw.info['sfreq']
        # No `scaling_time`: newer MNE rejects it, and the time column is
        # dropped by the channel selection anyway. Default amplitude scalings
        # are kept, as in the original call.
        raw_ch = raw.to_data_frame()[select_ch].to_numpy()
    finally:
        raw.close()
    print(raw_ch.shape)

    # get_labels may populate the module-level `epoch_label_map`; reset it so
    # events from earlier subjects are never carried over.
    epoch_label_map = {}
    epoch_label_map = get_labels(ann_fname)

    x, y = _epochs_from_events(raw_ch, epoch_label_map, sampling_rate)
    print(x.shape)
    print(y.shape)
    if len(y) == 0:
        print(f"No complete labelled epochs for subject {subject_id}; skipping")
        continue

    data_save = {
        "samples": torch.from_numpy(x).float(),
        "labels": torch.from_numpy(y),
    }
    torch.save(data_save, os.path.join(save_dir, f"subject_{subject_id}.pt"))
    print(f" ---------- Done with Subject {subject_id} ---------")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
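In line 94 (the `raw.to_data_frame(...)` call), change it to: `raw_ch_df = raw.to_data_frame(scalings=sampling_rate)[select_ch]`, because newer MNE versions no longer accept the `scaling_time` argument. Caution: `scalings` rescales the signal amplitudes, not the time axis, so passing the sampling rate would scale the signals by that factor. Simply dropping the argument (`raw_ch_df = raw.to_data_frame()[select_ch]`) avoids this.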