""" Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model """ import os import argparse import numpy as np from tabulate import tabulate from sklearn.feature_selection import VarianceThreshold from sklearn.metrics import roc_auc_score from sklearn.preprocessing import StandardScaler from model import Tox21XGBClassifier SEED = 999 DATA_FOLDER = "data/" parser = argparse.ArgumentParser(description="XGBoost Training script for Tox21 dataset") parser.add_argument( "--model_dir", type=str, default="assets", ) def main(args): print("Preprocess train molecules") data_path = os.path.join(DATA_FOLDER, "tox21_data.npz") full_data = np.load(data_path, allow_pickle=True) features = full_data["features"] labels = full_data["labels"] sets = full_data["sets"] # Handle inf/nan features: instead of dropping columns, zero-out entire affected columns # so that VarianceThreshold will remove them later, keeping indices aligned. bad_entries = np.isinf(features) | np.isnan(features) bad_cols = np.any(bad_entries, axis=0) if np.any(bad_cols): features[:, bad_cols] = 0.0 train_val_mask = sets != "test" # TMP fix should be "validation" ? train_X = features[train_val_mask] train_y = labels[train_val_mask] test_mask = sets == "test" val_X = features[test_mask] val_y = labels[test_mask] task_config = { "NR-AR": { "max_depth": 4, "min_child_weight": 1.1005779061921914, "gamma": 0.1317988706679324, "learning_rate": 0.039645108160965156, "subsample": 0.7296241662412439, "colsample_bytree": 0.8021365422870282, "reg_alpha": 3.3237336705963336e-06, "reg_lambda": 0.5602005185114373, "colsample_bylevel": 0.6436881915714322, "max_bin": 320, "grow_policy": "depthwise", "var_threshold": 0.007666987709838448 }, "NR-AR-LBD": { "max_depth": 4, "min_child_weight": 4.1987212703698695, "gamma": 1.2762015931613548, "learning_rate": 0.15154599977311695, "subsample": 0.6695940698634157, "colsample_bytree": 0.7739932636137854, "reg_alpha": 0.07898626960219088, "reg_lambda": 8.571012949754111, "colsample_bylevel": 0.9853057670318977, "max_bin": 512, "grow_policy": "lossguide", "var_threshold": 0.00037667540735397795 }, "NR-AhR": { "max_depth": 5, "min_child_weight": 6.689827023187083, "gamma": 0.05246277760115231, "learning_rate": 0.04756606141238733, "subsample": 0.8679211962117436, "colsample_bytree": 0.6095873089337578, "reg_alpha": 2.9267916989096844e-05, "reg_lambda": 0.16597411475484836, "colsample_bylevel": 0.6109587378961451, "max_bin": 192, "grow_policy": "lossguide", "var_threshold": 0.006450426707708987 }, "NR-Aromatase": { "max_depth": 3, "min_child_weight": 3.2876314247596152, "gamma": 0.19699266508924895, "learning_rate": 0.05088088932843542, "subsample": 0.7865649204014827, "colsample_bytree": 0.7251861382401115, "reg_alpha": 1.5663141562519894e-05, "reg_lambda": 0.8079227014059855, "colsample_bylevel": 0.6264563203168154, "max_bin": 320, "grow_policy": "lossguide", "var_threshold": 0.008210794229202779 }, "NR-ER": { "max_depth": 4, "min_child_weight": 5.780102015649284, "gamma": 1.4129142474001934, "learning_rate": 0.030962338755374925, "subsample": 0.6495287204129598, "colsample_bytree": 0.6052286799267346, "reg_alpha": 2.350761568396455e-08, "reg_lambda": 0.09630529926179951, "colsample_bylevel": 0.7431813327243276, "max_bin": 384, "grow_policy": "lossguide", "var_threshold": 0.0023810780862365695 }, "NR-ER-LBD": { "max_depth": 5, "min_child_weight": 9.173052917805649, "gamma": 1.0722539699322629, "learning_rate": 0.04237749698413915, "subsample": 0.7066072339657229, "colsample_bytree": 0.6813795582720684, "reg_alpha": 0.00023207537137377197, "reg_lambda": 15.088634424806914, "colsample_bylevel": 0.7799437417755278, "max_bin": 384, "grow_policy": "depthwise", "var_threshold": 0.0019169350680113165 }, "NR-PPAR-gamma": { "max_depth": 6, "min_child_weight": 5.174007598815524, "gamma": 1.9912192366255241, "learning_rate": 0.05540828755212913, "subsample": 0.6903953157523113, "colsample_bytree": 0.8663027348173384, "reg_alpha": 2.083339410970234e-08, "reg_lambda": 0.015396790332761562, "colsample_bylevel": 0.9751745752733803, "max_bin": 320, "grow_policy": "lossguide", "var_threshold": 0.0029616070252124786 }, "SR-ARE": { "max_depth": 7, "min_child_weight": 9.1659526731455, "gamma": 0.697265411436678, "learning_rate": 0.06570769871964029, "subsample": 0.9905868520803529, "colsample_bytree": 0.9320468198902392, "reg_alpha": 0.0015832053017691588, "reg_lambda": 0.05920338550334178, "colsample_bylevel": 0.9881491817036743, "max_bin": 128, "grow_policy": "lossguide", "var_threshold": 0.002817440527458996 }, "SR-ATAD5": { "max_depth": 8, "min_child_weight": 3.840348891355251, "gamma": 1.6154505675458388, "learning_rate": 0.13247082849598005, "subsample": 0.8051455662822469, "colsample_bytree": 0.8812075918541051, "reg_alpha": 1.0831755964182738e-08, "reg_lambda": 27.095693383578947, "colsample_bylevel": 0.636617995280427, "max_bin": 256, "grow_policy": "depthwise", "var_threshold": 0.009669430411280284 }, "SR-HSE": { "max_depth": 9, "min_child_weight": 6.413184249228777, "gamma": 1.033704331418744, "learning_rate": 0.05274739499143931, "subsample": 0.8865620043291726, "colsample_bytree": 0.6816866072800449, "reg_alpha": 0.058835365152010946, "reg_lambda": 0.020754661410877756, "colsample_bylevel": 0.9110208090854688, "max_bin": 512, "grow_policy": "lossguide", "var_threshold": 0.005674926071804129 }, "SR-MMP": { "max_depth": 5, "min_child_weight": 9.817728618387365, "gamma": 1.174192311657815, "learning_rate": 0.0469463693712702, "subsample": 0.7551958380501903, "colsample_bytree": 0.7909988895785574, "reg_alpha": 0.00015815798249652454, "reg_lambda": 0.07975430070894152, "colsample_bylevel": 0.6649592956153568, "max_bin": 128, "grow_policy": "depthwise", "var_threshold": 0.006024127982297082 }, "SR-p53": { "max_depth": 8, "min_child_weight": 5.038486734836349, "gamma": 1.807085258740345, "learning_rate": 0.1096533837056875, "subsample": 0.71588646279992, "colsample_bytree": 0.8086559814485024, "reg_alpha": 3.864250735509029e-08, "reg_lambda": 0.03548737332001143, "colsample_bylevel": 0.7740614694930106, "max_bin": 128, "grow_policy": "depthwise", "var_threshold": 0.008637178477182731 }, } results = {} for i, task in enumerate(task_config.keys()): npos = np.nansum(train_y[:, i]) nneg = np.sum(~np.isnan(train_y[:, i])) - npos task_config[task].update({ "tree_method": "hist", "n_estimators": 10_000, "early_stopping_rounds": 50, "eval_metric": "auc", "scale_pos_weight": nneg / max(npos, 1), "device": "cpu", }) model = Tox21XGBClassifier(seed=SEED, task_configs=task_config) print("Start training.") for i, task in enumerate(model.tasks): #print(model.model[task]) # Training ----------------------- task_labels = train_y[:, i] label_mask = ~np.isnan(task_labels) task_data = train_X[label_mask] task_labels = task_labels[label_mask].astype(int) # Remove low variance features and scale var_thresh = VarianceThreshold(threshold=task_config[task]["var_threshold"]) task_data = var_thresh.fit_transform(task_data) scaler = StandardScaler() task_data = scaler.fit_transform(task_data) model.feature_processors[task] = { "selector": var_thresh, "scaler": scaler, } # From X_train split 10% for an early stopping validation set np.random.seed(SEED) random_numbers = np.random.rand(task_data.shape[0]) es_val_mask = random_numbers < 0.1 es_train_mask = random_numbers >= 0.1 X_es_val, y_es_val = task_data[es_val_mask], task_labels[es_val_mask] X_es_train, y_es_train = task_data[es_train_mask], task_labels[es_train_mask] print(f"Fit task {task} using {sum(label_mask)} samples and {task_data.shape[1]} features") model.fit(task, X_es_train, y_es_train, eval_set=[(X_es_val, y_es_val)], verbose=False) # Evaluation ----------------------- val_task_labels = val_y[:, i] val_label_mask = ~np.isnan(val_task_labels) val_task_labels = val_task_labels[val_label_mask].astype(int) val_task_data = val_X[val_label_mask] val_task_data = model.feature_processors[task]["selector"].transform(val_task_data) val_task_data = model.feature_processors[task]["scaler"].transform(val_task_data) # Evaluate model pred = model.predict(task, val_task_data) results[task] = [roc_auc_score(y_true=val_task_labels, y_score=pred)] print(f"Save model under {args.model_dir}") model.save_model(args.model_dir) print("Results:") print(tabulate(results, headers="keys")) print("Average: ", sum([val[0] for val in results.values()]) / len(results)) if __name__ == "__main__": args = parser.parse_args() main(args)