import os
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import scipy.signal as signal
from joblib import Parallel, delayed
from scipy.signal import iirnotch
from tqdm import tqdm


def sequence_to_seconds(seq_len, fs):
    """Convert a length in samples to a duration in seconds."""
    return seq_len / fs


def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
    """Apply a zero-phase IIR notch filter to suppress powerline interference.

    `data` has shape (samples, channels); the notch at `notch_freq` Hz is
    applied to every channel.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    # filtfilt accepts an axis argument, so no per-channel loop is needed.
    return signal.filtfilt(b, a, data, axis=0)
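

# Quick sanity check (illustrative sketch, not part of the pipeline):
# suppress 50 Hz mains hum in a synthetic two-channel signal at 2 kHz.
#
#   t = np.arange(4000) / 2000.0
#   hum = 0.5 * np.sin(2 * np.pi * 50.0 * t)
#   noisy = np.stack([np.sin(2 * np.pi * 10.0 * t) + hum] * 2, axis=1)
#   clean = notch_filter(noisy, notch_freq=50.0, Q=30.0, fs=2000.0)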


def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
    """Apply a zero-phase Butterworth band-pass filter to multichannel EMG.

    `emg` has shape (samples, channels); `lowcut` and `highcut` are in Hz
    and are normalized by the Nyquist frequency before filter design.
    """
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    return signal.filtfilt(b, a, emg, axis=0)
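

# The two filters compose into the preprocessing chain used by
# process_one_recording below (shown here as a sketch; `raw` is any
# (samples, channels) float array):
#
#   cleaned = notch_filter(
#       bandpass_filter_emg(raw, 20, 450, fs=2000.0), 50, 30, fs=2000.0
#   )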


def process_emg_features(emg, window_size=1000, stride=500):
    """Slice a (samples, channels) array into overlapping windows.

    Returns an array of shape (n_windows, window_size, channels); any
    trailing partial window is dropped.
    """
    segs = []
    N = len(emg)
    # Only start a window where a full window_size of samples remains.
    for start in range(0, N - window_size + 1, stride):
        segs.append(emg[start : start + window_size])
    return np.array(segs)
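

# Worked example: with N = 5000 samples, window_size = 1000 and stride = 500,
# windows start at 0, 500, ..., 4000, giving
# floor((5000 - 1000) / 500) + 1 = 9 segments of shape (1000, channels).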


def process_one_recording(file_path, fs=2000.0, window_size=1000, stride=500):
    """Load one emg2pose recording, then filter, z-score, and segment it.

    Returns an array of shape (n_windows, window_size, channels), ready to
    be collected by the parallel workers in `main`.
    """
    with h5py.File(file_path, "r") as f:
        timeseries = f["emg2pose"]["timeseries"]
        emg = timeseries["emg"][:].astype(np.float32)

    # Band-pass to the usable surface-EMG band, then notch out 50 Hz mains hum.
    emg_filt = bandpass_filter_emg(emg, 20, 450, fs=fs)
    emg_filt = notch_filter(emg_filt, 50, 30, fs=fs)

    # Per-channel z-scoring, guarding against zero-variance channels.
    mu = emg_filt.mean(axis=0)
    sd = emg_filt.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0
    emg_z = (emg_filt - mu) / sd

    return process_emg_features(emg_z, window_size, stride)
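

# Example (the file name is hypothetical):
#
#   segs = process_one_recording("session_0001.hdf5", fs=2000.0,
#                                window_size=1000, stride=500)
#   segs.shape  # -> (n_windows, 1000, n_channels)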


def main():
    import argparse

    parser = argparse.ArgumentParser(
        description="Process EMG data from the emg2pose dataset."
    )
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument(
        "--seq_len", type=int, required=True,
        help="Size of the window in samples for segmentation.",
    )
    parser.add_argument(
        "--stride", type=int, required=True,
        help="Step size between windows in samples for segmentation.",
    )
    parser.add_argument(
        "--subsample", type=float, default=1.0,
        help="Fraction of recordings to keep per split (1.0 keeps everything).",
    )
    parser.add_argument(
        "--n_jobs", type=int, default=-1,
        help="Number of parallel jobs to run. -1 means use all available cores.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    args = parser.parse_args()

    data_dir = args.data_dir
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    fs = 2000.0
    window_size, stride = args.seq_len, args.stride

    window_seconds = sequence_to_seconds(window_size, fs)
    print(f"Window size: {window_size} samples ({window_seconds:.2f} seconds)")

    df = pd.read_csv(os.path.join(data_dir, "metadata.csv"))
    if args.subsample < 1.0:
        # Thin out each split by the same fraction so the train/val/test
        # proportions are preserved.
        df = df.groupby("split", group_keys=False).apply(
            lambda x: x.sample(frac=args.subsample, random_state=args.seed)
        )
    df = df.reset_index(drop=True)

    # Resolve each session's HDF5 path, grouped by split.
    splits = {}
    for split, df_ in df.groupby("split"):
        splits[split] = [
            Path(data_dir).expanduser().joinpath(f"{session}.hdf5")
            for session in df_.filename
        ]

    # Keys come from the metadata rather than a hard-coded train/val/test
    # list, so unexpected split names cannot raise a KeyError.
    all_data = {split: [] for split in splits}

    for split, files in splits.items():
        # Filter, normalize, and segment every recording in parallel.
        results = Parallel(n_jobs=args.n_jobs)(
            delayed(process_one_recording)(file_path, fs, window_size, stride)
            for file_path in tqdm(files, desc=f"Processing {split} files")
        )
        all_data[split].extend(results)

        # Stack the per-recording windows into (n_windows, window_size, channels),
        # then transpose to channels-first (n_windows, channels, window_size).
        X = np.concatenate(all_data[split], axis=0)
        X = X.transpose(0, 2, 1)

        with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
            hf.create_dataset("data", data=X)
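

# Example invocation (script name and paths are illustrative):
#
#   python process_emg.py --data_dir ~/emg2pose_data --save_dir ./processed \
#       --seq_len 1000 --stride 500 --subsample 0.1 --n_jobs 8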


if __name__ == "__main__":
    main()