Created
May 7, 2024 02:25
-
-
Save emadeldeen24/de2b35fa5a26b13e26d40418158a96c9 to your computer and use it in GitHub Desktop.
SHHS2_preprocessing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import warnings
import xml.etree.ElementTree as ET

import numpy as np
import torch
from mne.io import read_raw_edf

# Hide DeprecationWarnings from the libraries used below.
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Drop MNE's INFO/DEBUG log records; WARNING and above still get through.
logging.getLogger('mne').setLevel(logging.WARNING)
# Length of one scoring epoch in seconds; every labelled span is cut into
# windows of this size.
EPOCH_SEC_SIZE = 30

# Annotation event text -> integer class. "Stage 3" and "Stage 4" are both
# mapped to class 3, so the output has five classes (0-4).
ann2label = dict(
    [
        ("Wake|0", 0),
        ("Stage 1 sleep|1", 1),
        ("Stage 2 sleep|2", 2),
        ("Stage 3 sleep|3", 3),
        ("Stage 4 sleep|4", 3),
        ("REM sleep|5", 4),
    ]
)

# Input recordings, their XML annotations, and the output folder for .pt files.
data_dir = "/mnt/data/emad/shhs/polysomnography/edfs/shhs2/"
ann_dir = "/mnt/data/emad/shhs/polysomnography/annotations-events-nsrr/shhs2/"
save_dir = "/mnt/data/emad/shhs/shhs2_pt/"
# Collect the EDF recordings and the XML annotation files (top level of each
# directory only; os.walk(...)[2] lists the files directly inside it).
filenames = next(os.walk(data_dir))[2]
annotation = next(os.walk(ann_dir))[2]

edf_fnames = sorted(
    os.path.join(data_dir, f)
    for f in filenames
    if os.path.splitext(f)[1] == '.edf'
)
# Test "shhs" against the file name only: ann_dir itself contains "shhs",
# so testing the full path would accept every file in the folder.
ann_fnames = sorted(
    os.path.join(ann_dir, f)
    for f in annotation
    if os.path.splitext(f)[1] == '.xml' and "shhs" in f
)

# Check already preprocessed files. Create the output folder first: os.walk
# yields nothing for a missing directory, and next() would then raise
# StopIteration.
os.makedirs(save_dir, exist_ok=True)
done_subjects = next(os.walk(save_dir))[2]
# Subject id = text after the last "_" of each saved .pt file's stem
# (e.g. "<prefix>_<id>.pt"); other files in the folder are ignored so they
# cannot mark a subject as done.
ids = sorted({
    os.path.splitext(f)[0].split("_")[-1]
    for f in done_subjects
    if os.path.splitext(f)[1] == '.pt'
})

edf_fnames = np.asarray(edf_fnames)
ann_fnames = np.asarray(ann_fnames)
def get_labels(ann_fname, label_map=None):
    """Read the labelled events of one XML annotation file.

    Each ``ScoredEvent`` element is read by child position: child 1 is the
    event text looked up in *label_map*, child 2 the start time and child 3
    the duration. Events whose text is not a key of *label_map*, or that have
    fewer than four children, are skipped.

    Args:
        ann_fname: Path of the XML annotation file.
        label_map: Mapping from event text to integer class. Defaults to the
            module-level ``ann2label`` when omitted.

    Returns:
        A new dict mapping ``(start, start + duration)`` (floats, in the same
        time unit as the file) to the integer class, in document order. A
        fresh dict is built on every call, so results from different files
        never mix.
    """
    if label_map is None:
        label_map = ann2label

    labels = {}
    for event in ET.parse(ann_fname).getroot().iter('ScoredEvent'):
        if len(event) < 4:
            continue
        concept = event[1].text
        # Check the label before parsing numbers, so malformed non-stage
        # events cannot abort the whole file.
        if concept not in label_map:
            continue
        start = float(event[2].text)
        duration = float(event[3].text)
        labels[(start, start + duration)] = label_map[concept]
    return labels
# Index annotation files by subject id, taken as the second-to-last
# "-"-separated field of the file name (presumably names like
# "shhs2-<id>-nsrr.xml" — confirm). Pairing by id rather than by position in
# two sorted lists means one missing or extra file cannot shift every later
# EDF onto the wrong annotation.
ann_by_subject = {}
for ann_path in ann_fnames:
    name_parts = os.path.basename(ann_path).split("-")
    if len(name_parts) >= 2:
        ann_by_subject[name_parts[-2]] = ann_path

for edf_path in edf_fnames:
    # Rebind to a fresh dict every iteration so labels from a previous
    # subject can never leak into this one.
    epoch_label_map = {}
    # Subject id = last "-"-separated field of the EDF name, without extension.
    subject_id = os.path.basename(edf_path).split("-")[-1].split(".")[0]
    if subject_id in ids:
        continue
    ann_path = ann_by_subject.get(subject_id)
    if ann_path is None:
        print(f"####### NO ANNOTATION FOR SUBJECT {subject_id} #########")
        continue

    raw = None
    try:
        print(f"Preprocessing subject: {subject_id}")
        raw = read_raw_edf(edf_path, preload=True, stim_channel=None, verbose=None)
        sampling_rate = raw.info['sfreq']
        select_ch = ['EEG(sec)', 'ECG', 'EMG', 'EOG(L)', 'EOG(R)', 'EEG']
        # NOTE(review): `scaling_time` is not accepted by newer MNE releases
        # of Raw.to_data_frame; confirm against the installed MNE version.
        # Rows are samples and columns are the channels in select_ch order.
        raw_ch = raw.to_data_frame(scaling_time=sampling_rate)[select_ch].to_numpy()

        epoch_label_map = get_labels(ann_path)

        # Number of samples in one scoring epoch.
        epoch_len = int(round(EPOCH_SEC_SIZE * sampling_rate))
        eeg_epochs = []  # arrays of shape (n_epochs, epoch_len, n_channels)
        labels = []
        for (epoch_start, epoch_end), stage_label in epoch_label_map.items():
            # Convert event times to sample indices; round rather than
            # truncate so float error cannot shift a boundary by one sample.
            start_sample = int(round(epoch_start * sampling_rate))
            end_sample = int(round(epoch_end * sampling_rate))
            segment = raw_ch[start_sample:end_sample]
            n_epochs = len(segment) // epoch_len
            if n_epochs == 0:
                continue
            # Drop any trailing partial epoch (e.g. an event that runs past
            # the end of the recording); np.split would raise on it.
            segment = segment[: n_epochs * epoch_len]
            eeg_epochs.append(segment.reshape(n_epochs, epoch_len, segment.shape[1]))
            labels.extend([stage_label] * n_epochs)

        if not eeg_epochs:
            print(f"####### NO LABELLED EPOCHS FOR SUBJECT {subject_id} #########")
            continue

        # Stack all epochs and put channels before time:
        # (n_total, epoch_len, n_channels) -> (n_total, n_channels, epoch_len).
        x = np.concatenate(eeg_epochs, axis=0).transpose(0, 2, 1)
        y = np.array(labels)
        assert len(x) == len(y)

        data_save = {
            "samples": torch.from_numpy(x).float(),
            "labels": torch.from_numpy(y),
        }
        torch.save(data_save, os.path.join(save_dir, f"shhs2_{subject_id}.pt"))
        print(f" ---------- Done with Subject {subject_id} ---------")
    except Exception as err:
        # Per-subject boundary: report the failure and move on. Unlike a bare
        # `except:`, this lets KeyboardInterrupt stop the run.
        print(f"####### ISSUE WITH SUBJECT {subject_id} #########: {err!r}")
    finally:
        if raw is not None:
            raw.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment