antoniaebner's picture
update code and requirements
9b98ab0
"""
This file includes an XGBoost model for Tox21.
As input, it takes a list of SMILES and outputs a nested dictionary with
SMILES and target names as keys.
"""
# ---------------------------------------------------------------------------------------
# Dependencies
import os
import joblib
import numpy as np
from xgboost import XGBClassifier
from .utils import TASKS
# ---------------------------------------------------------------------------------------
class Tox21XGBClassifier:
"""A XGBoost classifier that assigns a toxicity score to a given SMILES string."""
def __init__(self, seed: int = 42, task_configs: dict | None = None) -> None:
"""Initialize an XGBoost classifier for each of the 12 Tox21 tasks.
Args:
seed (int, optional): seed for XGBoost to ensure reproducibility. Defaults to 42.
task_configs (dict | None, optional): dictionary containing task-specific
hyperparameters. If None, default hyperparameters are used for all tasks.
Defaults to None.
"""
self.tasks = TASKS
self.model = {
task: (
XGBClassifier(random_state=seed, n_jobs=8)
if task_configs is None
else XGBClassifier(
**{
k: v
for k, v in task_configs[task].items()
if k != "var_threshold"
},
random_state=seed,
n_jobs=8,
)
)
for task in self.tasks
}
self.feature_processors = {}
def load_model(self, dir: str) -> None:
"""Loads the model from a given directory
Args:
dir (str): directory to load model from
"""
self.model = joblib.load(os.path.join(dir, "xgb_alltasks.joblib"))
self.feature_processors = joblib.load(
os.path.join(dir, "feature_processors.pkl")
)
def save_model(self, dir: str) -> None:
"""Saves the model to a given directory
Args:
dir (str): directory to save model to
"""
model_path = os.path.join(dir, "xgb_alltasks.joblib")
feature_processor_path = os.path.join(dir, "feature_processors.pkl")
os.makedirs(dir, exist_ok=True)
joblib.dump(self.model, model_path)
joblib.dump(self.feature_processors, feature_processor_path)
def fit(
self, task: str, input_features: np.ndarray, labels: np.ndarray, **kwargs
) -> None:
"""Train XGBoost for a given task
Args:
task (str): task to train
input_features (np.ndarray): training features
labels (np.ndarray): training labels
"""
assert task in self.tasks, f"Unknown task: {task}"
self.model[task].fit(input_features, labels, **kwargs)
def predict(self, task: str, features: np.ndarray) -> np.ndarray:
"""Predicts labels for a given Tox21 target using molecule features
Args:
task (str): the Tox21 target to predict for
features (np.ndarray): molecule features used for prediction
Returns:
np.ndarray: predicted probability for positive class
"""
assert task in self.tasks, f"Unknown task: {task}"
assert (
len(features.shape) == 2
), f"Function expects 2D np.array. Current shape: {features.shape}"
preds = self.model[task].predict_proba(features)
return preds[:, 1]