File size: 3,578 Bytes
87e7d05
 
 
 
 
 
 
 
 
 
 
 
 
 
9b98ab0
87e7d05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b98ab0
 
 
 
 
 
 
 
 
 
 
 
87e7d05
 
 
 
 
 
 
 
 
 
 
 
9b98ab0
 
 
87e7d05
 
 
 
 
 
 
 
 
 
 
 
 
 
9b98ab0
 
 
87e7d05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
This file includes an XGBoost model for Tox21.
As input, it takes a list of SMILES, and it outputs a nested dictionary with
SMILES and target names as keys.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
import os
import joblib

import numpy as np
from xgboost import XGBClassifier

from .utils import TASKS


# ---------------------------------------------------------------------------------------
class Tox21XGBClassifier:
    """A XGBoost classifier that assigns a toxicity score to a given SMILES string."""

    def __init__(self, seed: int = 42, task_configs: dict | None = None) -> None:
        """Initialize an XGBoost classifier for each of the 12 Tox21 tasks.

        Args:
            seed (int, optional): seed for XGBoost to ensure reproducibility. Defaults to 42.
            task_configs (dict | None, optional): dictionary containing task-specific
                hyperparameters. If None, default hyperparameters are used for all tasks.
                Defaults to None.
        """
        self.tasks = TASKS
        self.model = {
            task: (
                XGBClassifier(random_state=seed, n_jobs=8)
                if task_configs is None
                else XGBClassifier(
                    **{
                        k: v
                        for k, v in task_configs[task].items()
                        if k != "var_threshold"
                    },
                    random_state=seed,
                    n_jobs=8,
                )
            )
            for task in self.tasks
        }
        self.feature_processors = {}

    def load_model(self, dir: str) -> None:
        """Loads the model from a given directory

        Args:
            dir (str): directory to load model from
        """
        self.model = joblib.load(os.path.join(dir, "xgb_alltasks.joblib"))
        self.feature_processors = joblib.load(
            os.path.join(dir, "feature_processors.pkl")
        )

    def save_model(self, dir: str) -> None:
        """Saves the model to a given directory

        Args:
            dir (str): directory to save model to
        """
        model_path = os.path.join(dir, "xgb_alltasks.joblib")
        feature_processor_path = os.path.join(dir, "feature_processors.pkl")
        os.makedirs(dir, exist_ok=True)

        joblib.dump(self.model, model_path)
        joblib.dump(self.feature_processors, feature_processor_path)

    def fit(
        self, task: str, input_features: np.ndarray, labels: np.ndarray, **kwargs
    ) -> None:
        """Train XGBoost for a given task

        Args:
            task (str): task to train
            input_features (np.ndarray): training features
            labels (np.ndarray): training labels
        """
        assert task in self.tasks, f"Unknown task: {task}"
        self.model[task].fit(input_features, labels, **kwargs)

    def predict(self, task: str, features: np.ndarray) -> np.ndarray:
        """Predicts labels for a given Tox21 target using molecule features

        Args:
            task (str): the Tox21 target to predict for
            features (np.ndarray): molecule features used for prediction

        Returns:
            np.ndarray: predicted probability for positive class
        """
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            len(features.shape) == 2
        ), f"Function expects 2D np.array. Current shape: {features.shape}"
        preds = self.model[task].predict_proba(features)
        return preds[:, 1]