antoniaebner committed
Commit 87e7d05 · 1 Parent(s): b0daa87

add Sohvi's code
Files changed (12)
  1. .gitignore +1 -0
  2. README.md +103 -0
  3. app.py +78 -0
  4. assets/tox_smarts.json +0 -0
  5. predict.py +63 -0
  6. requirements.txt +10 -0
  7. src/__init__.py +0 -0
  8. src/model.py +90 -0
  9. src/preprocess.py +263 -0
  10. src/push_assets.py +12 -0
  11. src/train.py +294 -0
  12. src/utils.py +432 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ */__pycache__/*
README.md ADDED
@@ -0,0 +1,103 @@
+ ---
+ title: Tox21 XGBoost Classifier
+ emoji: 🚀
+ colorFrom: green
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ short_description: XGBoost baseline classifier for Tox21
+ ---
+
+ # Tox21 XGBoost Classifier
+
+ This repository hosts a Hugging Face Space that provides an example API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/tschouis/tox21_leaderboard).
+
+ In this example, we train an XGBoost classifier on the Tox21 targets and save the trained model in the `assets/` folder.
+
+ **Important:** For leaderboard submission, your Space does not need to include training code. It only needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a prediction dictionary with SMILES and targets as keys. Consequently, any preprocessing of SMILES strings must be executed on the fly during inference.
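+
+ A minimal sketch of that skeleton (the signature matches `predict.py` in this repository; the body is illustrative only):
+
+ ```python
+ def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
+     # featurize the SMILES on the fly, run your model, and return
+     # {"<smiles>": {"<target>": <prediction>, ...}, ...}
+     ...
+ ```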
+
+ # Repository Structure
+ - `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
+ - `app.py` - FastAPI application wrapper (can be used as-is).
+ - `src/` - Core model & preprocessing logic:
+   - `preprocess.py` - SMILES preprocessing pipeline
+   - `model.py` - XGBoost classifier wrapper
+   - `train.py` - Script to train the classifier
+   - `utils.py` - Constants and helper functions
+
+ # Quickstart with Spaces
+
+ You can easily adapt this project in your own Hugging Face account:
+
+ - Open this Space on Hugging Face.
+ - Click "Duplicate this Space" (top-right corner).
+ - Modify `src/` for your preprocessing pipeline and model class.
+ - Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.
+
+ That's it: your model will be available as an API endpoint for the Tox21 Leaderboard.
+
+ # Installation
+ To run (and train) the XGBoost classifier, clone the repository and install dependencies:
+
+ ```bash
+ git clone https://huggingface.co/spaces/tschouis/tox21_xgboost_classifier
+ cd tox21_xgboost_classifier
+
+ conda create -n tox21_xgb_cls python=3.11
+ conda activate tox21_xgb_cls
+ pip install -r requirements.txt
+ ```
+
+ # Training
+
+ To train the XGBoost model from scratch:
+
+ ```bash
+ python -m src.train
+ ```
+
+ This will:
+
+ 1. Load and preprocess the Tox21 training dataset.
+ 2. Train an XGBoost classifier.
+ 3. Save the trained model to the `assets/` folder.
+ 4. Evaluate the trained XGBoost classifier on the validation split.
+
+ # Inference
+
+ For inference, you only need `predict.py`.
+
+ Example usage inside Python:
+
+ ```python
+ from predict import predict
+
+ smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
+ results = predict(smiles_list)
+
+ print(results)
+ ```
+
+ The output will be a nested dictionary in the format:
+
+ ```python
+ {
+     "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
+     "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
+     "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
+ }
+ ```
+
+ # Notes
+
+ - Only adapting `predict.py` for your model inference is required for leaderboard submission.
+ - Training (`src/train.py`) is provided for reproducibility.
+ - Preprocessing (here inside `src/preprocess.py`) must be applied at inference time, not only during training.
app.py ADDED
@@ -0,0 +1,78 @@
+ """
+ This is the main entry point for the FastAPI application.
+ The app handles requests to predict toxicity for a list of SMILES strings.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies and global variable definition
+ import os
+ from typing import Dict, List
+
+ from fastapi import FastAPI
+ from pydantic import BaseModel, Field
+
+ from predict import predict as predict_func
+
+ API_KEY = os.getenv("API_KEY")  # set via Space Secrets
+
+
+ # ---------------------------------------------------------------------------------------
+ class Request(BaseModel):
+     smiles: List[str] = Field(min_items=1, max_items=1000)
+
+
+ class Response(BaseModel):
+     predictions: dict
+     model_info: Dict[str, str] = {}
+
+
+ app = FastAPI(title="toxicity-api")
+
+
+ @app.get("/")
+ def root():
+     return {
+         "message": "Toxicity Prediction API",
+         "endpoints": {
+             "/metadata": "GET - API metadata and capabilities",
+             "/healthz": "GET - Health check",
+             "/predict": "POST - Predict toxicity for SMILES",
+         },
+         "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
+     }
+
+
+ @app.get("/metadata")
+ def metadata():
+     return {
+         "name": "AwesomeTox",
+         "version": "1.0.0",
+         "max_batch_size": 256,
+         "tox_endpoints": [
+             "NR-AR",
+             "NR-AR-LBD",
+             "NR-AhR",
+             "NR-Aromatase",
+             "NR-ER",
+             "NR-ER-LBD",
+             "NR-PPAR-gamma",
+             "SR-ARE",
+             "SR-ATAD5",
+             "SR-HSE",
+             "SR-MMP",
+             "SR-p53",
+         ],
+     }
+
+
+ @app.get("/healthz")
+ def healthz():
+     return {"ok": True}
+
+
+ @app.post("/predict", response_model=Response)
+ def predict(request: Request):
+     predictions = predict_func(request.smiles)
+     return {
+         "predictions": predictions,
+         "model_info": {"name": "tox21_xgb_classifier", "version": "1.0.0"},
+     }
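For reference, a minimal client for this endpoint (a sketch using the `requests` package, which is not listed in `requirements.txt`; it assumes the app is served locally on port 7860, the default for Docker Spaces, and `<API_KEY>` is a placeholder — this example app reads `API_KEY` but does not enforce it):

```python
import requests

# hypothetical local URL; replace with your duplicated Space's endpoint
url = "http://localhost:7860/predict"
payload = {"smiles": ["CCO", "c1ccccc1"]}
headers = {"Authorization": "Bearer <API_KEY>"}  # placeholder

resp = requests.post(url, json=payload, headers=headers)
resp.raise_for_status()
print(resp.json()["predictions"])
```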
assets/tox_smarts.json ADDED
The diff for this file is too large to render. See raw diff
 
predict.py ADDED
@@ -0,0 +1,63 @@
+ """
+ This file includes a predict function for Tox21.
+ It takes a list of SMILES as input and outputs a nested dictionary with
+ SMILES and target names as keys.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ from collections import defaultdict
+
+ import numpy as np
+
+ from src.model import Tox21XGBClassifier
+ from src.preprocess import create_descriptors
+
+ # ---------------------------------------------------------------------------------------
+
+
+ def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
+     """Applies the classifier to a list of SMILES strings. Returns prediction=0.0 for
+     any molecule that could not be cleaned.
+
+     Args:
+         smiles_list (list[str]): list of SMILES strings
+
+     Returns:
+         dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
+     """
+     print(f"Received {len(smiles_list)} SMILES strings")
+     # preprocessing pipeline
+     features, mol_mask = create_descriptors(smiles_list)
+     print(f"Created {features.shape[1]} descriptors for the molecules.")
+     print(
+         f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning. "
+         "All predictions for these will be set to 0.0."
+     )
+
+     # setup model
+     model = Tox21XGBClassifier(seed=42)
+     model_dir = "assets/"
+     model.load_model(model_dir)
+     print(f"Loaded model and feature processors from {model_dir}")
+
+     # make predictions
+     predictions = defaultdict(dict)
+     # index of each molecule's row in `features` (only cleaned molecules have a row)
+     feat_indices = np.cumsum(mol_mask) - 1
+     for target in model.tasks:
+         feature_processors = model.feature_processors[target]
+         task_features = feature_processors["selector"].transform(features)
+         task_features = feature_processors["scaler"].transform(task_features)
+         target_pred = model.predict(target, task_features)
+         for smiles, is_clean, i in zip(smiles_list, mol_mask, feat_indices):
+             predictions[smiles][target] = float(target_pred[i]) if is_clean else 0.0
+     return predictions
+
+
+ if __name__ == "__main__":
+     # simple test
+     test_smiles = [
+         "CCO",
+         "CCN",
+         "invalid_smiles",
+     ]
+     preds = predict(test_smiles)
+     print(preds)
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn[standard]
+ statsmodels
+ rdkit
+ numpy
+ scikit-learn==1.7.1
+ joblib
+ tabulate
+ datasets
+ xgboost==3.0.5
src/__init__.py ADDED
File without changes
src/model.py ADDED
@@ -0,0 +1,90 @@
+ """
+ This file includes an XGBoost model wrapper for Tox21.
+ It holds one binary XGBoost classifier per Tox21 task, together with the
+ fitted per-task feature processors.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ import os
+
+ import joblib
+ import numpy as np
+ from xgboost import XGBClassifier
+
+ from src.utils import TASKS
+
+
+ # ---------------------------------------------------------------------------------------
+ class Tox21XGBClassifier:
+     """An XGBoost classifier that assigns a toxicity score to a given SMILES string."""
+
+     def __init__(self, seed: int = 42, task_configs: dict | None = None) -> None:
+         """Initialize an XGBoost classifier for each of the 12 Tox21 tasks.
+
+         Args:
+             seed (int, optional): seed for XGBoost to ensure reproducibility. Defaults to 42.
+             task_configs (dict | None, optional): dictionary containing task-specific
+                 hyperparameters. If None, default hyperparameters are used for all tasks.
+                 Defaults to None.
+         """
+         self.tasks = TASKS
+         self.model = {
+             task: XGBClassifier(random_state=seed, n_jobs=8) if task_configs is None
+             else XGBClassifier(
+                 **{k: v for k, v in task_configs[task].items() if k != "var_threshold"},
+                 random_state=seed, n_jobs=8
+             )
+             for task in self.tasks
+         }
+         self.feature_processors = {}
+
+     def load_model(self, dir: str) -> None:
+         """Loads the model from a given directory.
+
+         Args:
+             dir (str): directory to load the model from
+         """
+         self.model = joblib.load(os.path.join(dir, "xgb_alltasks.joblib"))
+         self.feature_processors = joblib.load(os.path.join(dir, "feature_processors.pkl"))
+
+     def save_model(self, dir: str) -> None:
+         """Saves the model to a given directory.
+
+         Args:
+             dir (str): directory to save the model to
+         """
+         model_path = os.path.join(dir, "xgb_alltasks.joblib")
+         feature_processor_path = os.path.join(dir, "feature_processors.pkl")
+         os.makedirs(dir, exist_ok=True)
+
+         joblib.dump(self.model, model_path)
+         joblib.dump(self.feature_processors, feature_processor_path)
+
+     def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray, **kwargs) -> None:
+         """Train XGBoost for a given task.
+
+         Args:
+             task (str): task to train
+             input_features (np.ndarray): training features
+             labels (np.ndarray): training labels
+         """
+         assert task in self.tasks, f"Unknown task: {task}"
+         self.model[task].fit(input_features, labels, **kwargs)
+
+     def predict(self, task: str, features: np.ndarray) -> np.ndarray:
+         """Predicts labels for a given Tox21 target using molecule features.
+
+         Args:
+             task (str): the Tox21 target to predict for
+             features (np.ndarray): molecule features used for prediction
+
+         Returns:
+             np.ndarray: predicted probability for the positive class
+         """
+         assert task in self.tasks, f"Unknown task: {task}"
+         assert (
+             len(features.shape) == 2
+         ), f"Function expects a 2D np.array. Current shape: {features.shape}"
+         preds = self.model[task].predict_proba(features)
+         return preds[:, 1]
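A short usage sketch for this wrapper (illustrative only; the random `X_train`/`y_train` stand in for one task's preprocessed feature matrix and binary labels):

```python
import numpy as np
from src.model import Tox21XGBClassifier

model = Tox21XGBClassifier(seed=42)

# hypothetical data: 100 molecules, 16 features, binary labels for one task
X_train = np.random.rand(100, 16)
y_train = np.random.randint(0, 2, size=100)

model.fit("NR-AR", X_train, y_train)
probs = model.predict("NR-AR", X_train)  # P(toxic) per molecule
model.save_model("assets/")
```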
src/preprocess.py ADDED
@@ -0,0 +1,263 @@
+ """
+ This file includes functions to create molecular descriptors.
+ It takes a list of SMILES as input and outputs a numpy array of descriptors.
+ """
+
+ import os
+ import json
+ import argparse
+
+ import numpy as np
+
+ from datasets import load_dataset
+
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
+ from rdkit.Chem.rdchem import Mol
+
+ from src.utils import (
+     TASKS,
+     HF_TOKEN,
+     USED_200_DESCR,
+     Standardizer,
+ )
+
+ parser = argparse.ArgumentParser(
+     description="Data preprocessing script for the Tox21 dataset"
+ )
+
+ parser.add_argument(
+     "--save_folder",
+     type=str,
+     default="data/",
+ )
+
+ parser.add_argument(
+     "--use_hf",
+     type=int,
+     default=0,
+ )
+
+ parser.add_argument(
+     "--tox_smarts_filepath",
+     type=str,
+     default="assets/tox_smarts.json",
+ )
+
+
+ def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
+     """This function creates cleaned RDKit mol objects from a list of SMILES.
+
+     Args:
+         smiles (list[str]): list of SMILES
+
+     Returns:
+         list[Mol]: list of cleaned molecules
+         np.ndarray[bool]: mask that contains False at index `i` if the molecule in
+             `smiles` at index `i` could not be cleaned and was removed.
+     """
+     sm = Standardizer(canon_taut=True)
+
+     clean_mol_mask = list()
+     mols = list()
+     for smile in smiles:
+         mol = Chem.MolFromSmiles(smile)
+         standardized_mol, _ = sm.standardize_mol(mol)
+         is_cleaned = standardized_mol is not None
+         clean_mol_mask.append(is_cleaned)
+         if not is_cleaned:
+             continue
+         can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
+         mols.append(can_mol)
+
+     return mols, np.array(clean_mol_mask)
+
+
+ def create_ecfp_fps(mols: list[Mol], radius=None, fpsize=None) -> np.ndarray:
+     """This function creates ECFP fingerprints for a list of molecules.
+
+     Args:
+         mols (list[Mol]): list of molecules
+
+     Returns:
+         np.ndarray: ECFP fingerprints of molecules
+     """
+     kwargs = {}
+     if fpsize is not None:
+         kwargs["fpSize"] = fpsize
+     if radius is not None:
+         kwargs["radius"] = radius
+     # the generator is stateless, so it can be created once for all molecules
+     gen = rdFingerprintGenerator.GetMorganGenerator(countSimulation=True, **kwargs)
+
+     ecfps = list()
+     for mol in mols:
+         fp_sparse_vec = gen.GetCountFingerprint(mol)
+         fp = np.zeros((0,), np.int8)
+         DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
+         ecfps.append(fp)
+
+     return np.array(ecfps)
+
+
+ def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
+     """This function creates MACCS keys for a list of molecules."""
+     maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
+     return np.array(maccs)
+
+
+ def get_tox_patterns(filepath: str) -> list:
+     """Parses the tox SMARTS patterns defined in tox_smarts.json.
+
+     Args:
+         filepath (str): path to the JSON file with the SMARTS definitions
+
+     Returns:
+         list: one (patterns, negations, merge_any) tuple per tox feature
+     """
+     # load patterns
+     with open(filepath) as f:
+         smarts_list = [s[1] for s in json.load(f)]
+
+     # Code does not work for this case
+     assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
+
+     # Chem.MolFromSmarts takes a long time, so it pays off to parse all the smarts first
+     # and then use them for all molecules. This gives a huge speedup over existing code.
+     # A list of patterns, whether to negate the match result, and how to join them to
+     # obtain one boolean value.
+     all_patterns = []
+     for smarts in smarts_list:
+         patterns = []  # list of smarts-patterns
+         # value for each of the patterns above; negates the match results later
+         negations = []
+
+         if " AND " in smarts:
+             smarts = smarts.split(" AND ")
+             merge_any = False  # if an ' AND ' is found, all 'subsmarts' have to match
+         else:
+             # If there is an ' OR ' present, it's enough if any of the 'subsmarts' match.
+             # This also accumulates smarts where neither ' OR ' nor ' AND ' occur.
+             smarts = smarts.split(" OR ")
+             merge_any = True
+
+         # for all subsmarts, check if they are preceded by 'NOT '
+         for s in smarts:
+             neg = s.startswith("NOT ")
+             if neg:
+                 s = s[4:]
+             patterns.append(Chem.MolFromSmarts(s))
+             negations.append(neg)
+
+         all_patterns.append((patterns, negations, merge_any))
+     return all_patterns
+
+
+ def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
+     """Matches the tox patterns against a molecule. Returns a boolean array."""
+     tox_data = []
+     for mol in mols:
+         mol_features = []
+         for patts, negations, merge_any in patterns:
+             matches = [mol.HasSubstructMatch(p) for p in patts]
+             matches = [m != n for m, n in zip(matches, negations)]
+             if merge_any:
+                 pres = any(matches)
+             else:
+                 pres = all(matches)
+             mol_features.append(pres)
+
+         tox_data.append(np.array(mol_features))
+
+     return np.array(tox_data)
+
+
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
+     """This function creates RDKit descriptors for a list of molecules.
+
+     Args:
+         mols (list[Mol]): list of molecules
+
+     Returns:
+         np.ndarray: RDKit descriptors of molecules
+     """
+     rdkit_descriptors = list()
+
+     for mol in mols:
+         descrs = []
+         for _, descr_calc_fn in Descriptors._descList:
+             descrs.append(descr_calc_fn(mol))
+
+         descrs = np.array(descrs)
+         descrs = descrs[USED_200_DESCR]
+         rdkit_descriptors.append(descrs)
+
+     return np.array(rdkit_descriptors)
+
+
+ def create_descriptors(smiles: list[str]) -> tuple[np.ndarray, np.ndarray]:
+     """Creates the full feature matrix (ECFP, tox, MACCS, RDKit descriptors) for a
+     list of SMILES, together with the mask of molecules that survived cleaning."""
+     print(f"Preprocess {len(smiles)} molecules")
+
+     # Create cleaned rdkit mol objects
+     mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
+     print("Cleaned molecules")
+
+     tox_patterns = get_tox_patterns("assets/tox_smarts.json")
+
+     # Create fingerprints and descriptors
+     ecfps = create_ecfp_fps(mols, radius=3, fpsize=8192)
+     print("Created ECFP fingerprints")
+
+     tox = create_tox_features(mols, tox_patterns)
+     print("Created Tox features")
+
+     maccs = create_maccs_keys(mols)
+     print("Created MACCS keys")
+
+     rdkit_descrs = create_rdkit_descriptors(mols)
+     print("Created RDKit descriptors")
+
+     features = np.concatenate((ecfps, tox, maccs, rdkit_descrs), axis=1)
+     return features, clean_mol_mask
+
+
+ def fill(features, mask, value=np.nan):
+     """Expands `features` (rows for cleaned molecules only) back to one row per input
+     molecule, filling the rows of removed molecules with `value`."""
+     n_mols = len(mask)
+     n_features = features.shape[1]
+
+     data = np.zeros(shape=(n_mols, n_features))
+     data.fill(value)
+     # rows where the mask is True correspond to the cleaned molecules
+     data[mask] = features
+     return data
+
+
+ def preprocess_tox21():
+     splits = ["train", "validation"]
+     ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
+
+     all_features, all_labels, all_split = [], [], []
+
+     for split in splits:
+         print(f"Preprocess {split} molecules")
+         smiles = list(ds[split]["smiles"])
+
+         features, mol_mask = create_descriptors(smiles)
+         print(f"Created {features.shape[1]} descriptors for {len(smiles)} molecules.")
+         print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning.")
+
+         labels = []
+         for task in TASKS:
+             datasplit = ds[split].to_pandas() if args.use_hf else ds[split]
+             # np.asarray works for both a pandas Series and a plain list column
+             labels.append(np.asarray(datasplit[task]))
+         labels = np.stack(labels, axis=1)
+
+         all_features.append(features)
+         all_labels.append(labels)
+         all_split.append([split] * len(smiles))
+
+     os.makedirs(args.save_folder, exist_ok=True)
+     save_path = f"{args.save_folder}/tox21_data.npz"
+     with open(save_path, "wb") as f:
+         np.savez_compressed(
+             f,
+             features=all_features,
+             labels=all_labels,
+             splits=all_split,
+         )
+     print(f"Saved preprocessed data to {save_path}")
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     preprocess_tox21()
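A small sketch of how `create_descriptors` and `fill` fit together (illustrative only; it assumes `assets/tox_smarts.json` is present, since `create_descriptors` reads it):

```python
import numpy as np
from src.preprocess import create_descriptors, fill

smiles = ["CCO", "not_a_smiles", "c1ccccc1"]
features, mask = create_descriptors(smiles)  # one row per *cleaned* molecule

# expand back to one row per input molecule; uncleaned rows become NaN
full = fill(features, mask, value=np.nan)
assert full.shape[0] == len(smiles)
print(np.isnan(full).any(axis=1))  # True only for the molecule that failed cleaning
```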
src/push_assets.py ADDED
@@ -0,0 +1,12 @@
+ from huggingface_hub import HfApi
+
+ from .utils import HF_TOKEN
+
+ api = HfApi()
+
+ api.upload_folder(
+     folder_path="assets/",
+     path_in_repo="assets",
+     repo_id="tschouis/tox21_xgboost_classifier",
+     repo_type="space",
+     token=HF_TOKEN,
+ )
src/train.py ADDED
@@ -0,0 +1,294 @@
+ """
+ Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model.
+ """
+
+ import os
+ import argparse
+
+ import numpy as np
+
+ from tabulate import tabulate
+
+ from sklearn.feature_selection import VarianceThreshold
+ from sklearn.metrics import roc_auc_score
+ from sklearn.preprocessing import StandardScaler
+
+ from src.model import Tox21XGBClassifier
+
+ SEED = 999
+ DATA_FOLDER = "data/"
+
+ parser = argparse.ArgumentParser(description="XGBoost Training script for Tox21 dataset")
+
+ parser.add_argument(
+     "--model_dir",
+     type=str,
+     default="assets",
+ )
+
+
+ def main(args):
+     print("Preprocess train molecules")
+     data_path = os.path.join(DATA_FOLDER, "tox21_data.npz")
+     full_data = np.load(data_path, allow_pickle=True)
+     # the npz stores one entry per split; flatten them into single arrays
+     features = np.concatenate(full_data["features"], axis=0)
+     labels = np.concatenate(full_data["labels"], axis=0)
+     splits = np.concatenate(full_data["splits"], axis=0)
+
+     # Handle inf/nan features: instead of dropping columns, zero-out entire affected columns
+     # so that VarianceThreshold will remove them later, keeping indices aligned.
+     bad_entries = np.isinf(features) | np.isnan(features)
+     bad_cols = np.any(bad_entries, axis=0)
+     if np.any(bad_cols):
+         features[:, bad_cols] = 0.0
+
+     train_mask = splits == "train"
+     train_X = features[train_mask]
+     train_y = labels[train_mask]
+     val_mask = splits == "validation"
+     val_X = features[val_mask]
+     val_y = labels[val_mask]
+
+     task_config = {
+         "NR-AR": {
+             "max_depth": 4,
+             "min_child_weight": 1.1005779061921914,
+             "gamma": 0.1317988706679324,
+             "learning_rate": 0.039645108160965156,
+             "subsample": 0.7296241662412439,
+             "colsample_bytree": 0.8021365422870282,
+             "reg_alpha": 3.3237336705963336e-06,
+             "reg_lambda": 0.5602005185114373,
+             "colsample_bylevel": 0.6436881915714322,
+             "max_bin": 320,
+             "grow_policy": "depthwise",
+             "var_threshold": 0.007666987709838448
+         },
+         "NR-AR-LBD": {
+             "max_depth": 4,
+             "min_child_weight": 4.1987212703698695,
+             "gamma": 1.2762015931613548,
+             "learning_rate": 0.15154599977311695,
+             "subsample": 0.6695940698634157,
+             "colsample_bytree": 0.7739932636137854,
+             "reg_alpha": 0.07898626960219088,
+             "reg_lambda": 8.571012949754111,
+             "colsample_bylevel": 0.9853057670318977,
+             "max_bin": 512,
+             "grow_policy": "lossguide",
+             "var_threshold": 0.00037667540735397795
+         },
+         "NR-AhR": {
+             "max_depth": 5,
+             "min_child_weight": 6.689827023187083,
+             "gamma": 0.05246277760115231,
+             "learning_rate": 0.04756606141238733,
+             "subsample": 0.8679211962117436,
+             "colsample_bytree": 0.6095873089337578,
+             "reg_alpha": 2.9267916989096844e-05,
+             "reg_lambda": 0.16597411475484836,
+             "colsample_bylevel": 0.6109587378961451,
+             "max_bin": 192,
+             "grow_policy": "lossguide",
+             "var_threshold": 0.006450426707708987
+         },
+         "NR-Aromatase": {
+             "max_depth": 3,
+             "min_child_weight": 3.2876314247596152,
+             "gamma": 0.19699266508924895,
+             "learning_rate": 0.05088088932843542,
+             "subsample": 0.7865649204014827,
+             "colsample_bytree": 0.7251861382401115,
+             "reg_alpha": 1.5663141562519894e-05,
+             "reg_lambda": 0.8079227014059855,
+             "colsample_bylevel": 0.6264563203168154,
+             "max_bin": 320,
+             "grow_policy": "lossguide",
+             "var_threshold": 0.008210794229202779
+         },
+         "NR-ER": {
+             "max_depth": 4,
+             "min_child_weight": 5.780102015649284,
+             "gamma": 1.4129142474001934,
+             "learning_rate": 0.030962338755374925,
+             "subsample": 0.6495287204129598,
+             "colsample_bytree": 0.6052286799267346,
+             "reg_alpha": 2.350761568396455e-08,
+             "reg_lambda": 0.09630529926179951,
+             "colsample_bylevel": 0.7431813327243276,
+             "max_bin": 384,
+             "grow_policy": "lossguide",
+             "var_threshold": 0.0023810780862365695
+         },
+         "NR-ER-LBD": {
+             "max_depth": 5,
+             "min_child_weight": 9.173052917805649,
+             "gamma": 1.0722539699322629,
+             "learning_rate": 0.04237749698413915,
+             "subsample": 0.7066072339657229,
+             "colsample_bytree": 0.6813795582720684,
+             "reg_alpha": 0.00023207537137377197,
+             "reg_lambda": 15.088634424806914,
+             "colsample_bylevel": 0.7799437417755278,
+             "max_bin": 384,
+             "grow_policy": "depthwise",
+             "var_threshold": 0.0019169350680113165
+         },
+         "NR-PPAR-gamma": {
+             "max_depth": 6,
+             "min_child_weight": 5.174007598815524,
+             "gamma": 1.9912192366255241,
+             "learning_rate": 0.05540828755212913,
+             "subsample": 0.6903953157523113,
+             "colsample_bytree": 0.8663027348173384,
+             "reg_alpha": 2.083339410970234e-08,
+             "reg_lambda": 0.015396790332761562,
+             "colsample_bylevel": 0.9751745752733803,
+             "max_bin": 320,
+             "grow_policy": "lossguide",
+             "var_threshold": 0.0029616070252124786
+         },
+         "SR-ARE": {
+             "max_depth": 7,
+             "min_child_weight": 9.1659526731455,
+             "gamma": 0.697265411436678,
+             "learning_rate": 0.06570769871964029,
+             "subsample": 0.9905868520803529,
+             "colsample_bytree": 0.9320468198902392,
+             "reg_alpha": 0.0015832053017691588,
+             "reg_lambda": 0.05920338550334178,
+             "colsample_bylevel": 0.9881491817036743,
+             "max_bin": 128,
+             "grow_policy": "lossguide",
+             "var_threshold": 0.002817440527458996
+         },
+         "SR-ATAD5": {
+             "max_depth": 8,
+             "min_child_weight": 3.840348891355251,
+             "gamma": 1.6154505675458388,
+             "learning_rate": 0.13247082849598005,
+             "subsample": 0.8051455662822469,
+             "colsample_bytree": 0.8812075918541051,
+             "reg_alpha": 1.0831755964182738e-08,
+             "reg_lambda": 27.095693383578947,
+             "colsample_bylevel": 0.636617995280427,
+             "max_bin": 256,
+             "grow_policy": "depthwise",
+             "var_threshold": 0.009669430411280284
+         },
+         "SR-HSE": {
+             "max_depth": 9,
+             "min_child_weight": 6.413184249228777,
+             "gamma": 1.033704331418744,
+             "learning_rate": 0.05274739499143931,
+             "subsample": 0.8865620043291726,
+             "colsample_bytree": 0.6816866072800449,
+             "reg_alpha": 0.058835365152010946,
+             "reg_lambda": 0.020754661410877756,
+             "colsample_bylevel": 0.9110208090854688,
+             "max_bin": 512,
+             "grow_policy": "lossguide",
+             "var_threshold": 0.005674926071804129
+         },
+         "SR-MMP": {
+             "max_depth": 5,
+             "min_child_weight": 9.817728618387365,
+             "gamma": 1.174192311657815,
+             "learning_rate": 0.0469463693712702,
+             "subsample": 0.7551958380501903,
+             "colsample_bytree": 0.7909988895785574,
+             "reg_alpha": 0.00015815798249652454,
+             "reg_lambda": 0.07975430070894152,
+             "colsample_bylevel": 0.6649592956153568,
+             "max_bin": 128,
+             "grow_policy": "depthwise",
+             "var_threshold": 0.006024127982297082
+         },
+         "SR-p53": {
+             "max_depth": 8,
+             "min_child_weight": 5.038486734836349,
+             "gamma": 1.807085258740345,
+             "learning_rate": 0.1096533837056875,
+             "subsample": 0.71588646279992,
+             "colsample_bytree": 0.8086559814485024,
+             "reg_alpha": 3.864250735509029e-08,
+             "reg_lambda": 0.03548737332001143,
+             "colsample_bylevel": 0.7740614694930106,
+             "max_bin": 128,
+             "grow_policy": "depthwise",
+             "var_threshold": 0.008637178477182731
+         },
+     }
+
+     results = {}
+
+     for i, task in enumerate(task_config.keys()):
+         npos = np.nansum(train_y[:, i])
+         nneg = np.sum(~np.isnan(train_y[:, i])) - npos
+         task_config[task].update({
+             "tree_method": "hist",
+             "n_estimators": 10_000,
+             "early_stopping_rounds": 50,
+             "eval_metric": "auc",
+             "scale_pos_weight": nneg / max(npos, 1),
+             "device": "cpu",
+         })
+
+     model = Tox21XGBClassifier(seed=SEED, task_configs=task_config)
+
+     print("Start training.")
+     for i, task in enumerate(model.tasks):
+         # Training -----------------------
+         task_labels = train_y[:, i]
+         label_mask = ~np.isnan(task_labels)
+         task_data = train_X[label_mask]
+         task_labels = task_labels[label_mask].astype(int)
+
+         # Remove low variance features and scale
+         var_thresh = VarianceThreshold(threshold=task_config[task]["var_threshold"])
+         task_data = var_thresh.fit_transform(task_data)
+         scaler = StandardScaler()
+         task_data = scaler.fit_transform(task_data)
+         model.feature_processors[task] = {
+             "selector": var_thresh,
+             "scaler": scaler,
+         }
+
+         # From X_train, split off 10% as an early stopping validation set
+         np.random.seed(SEED)
+         random_numbers = np.random.rand(task_data.shape[0])
+         es_val_mask = random_numbers < 0.1
+         es_train_mask = random_numbers >= 0.1
+         X_es_val, y_es_val = task_data[es_val_mask], task_labels[es_val_mask]
+         X_es_train, y_es_train = task_data[es_train_mask], task_labels[es_train_mask]
+
+         print(f"Fit task {task} using {sum(label_mask)} samples and {task_data.shape[1]} features")
+         model.fit(task, X_es_train, y_es_train, eval_set=[(X_es_val, y_es_val)], verbose=False)
+
+         # Evaluation -----------------------
+         val_task_labels = val_y[:, i]
+         val_label_mask = ~np.isnan(val_task_labels)
+         val_task_labels = val_task_labels[val_label_mask].astype(int)
+         val_task_data = val_X[val_label_mask]
+         val_task_data = model.feature_processors[task]["selector"].transform(val_task_data)
+         val_task_data = model.feature_processors[task]["scaler"].transform(val_task_data)
+
+         # Evaluate model
+         pred = model.predict(task, val_task_data)
+         results[task] = [roc_auc_score(y_true=val_task_labels, y_score=pred)]
+
+     print(f"Save model under {args.model_dir}")
+     model.save_model(args.model_dir)
+
+     print("Results:")
+     print(tabulate(results, headers="keys"))
+     print("Average: ", sum(val[0] for val in results.values()) / len(results))
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     main(args)
src/utils.py ADDED
@@ -0,0 +1,432 @@
+ ## These MolStandardizer classes are due to Paolo Tosco.
+ ## They were taken from the FS-Mol GitHub repository
+ ## (https://github.com/microsoft/FS-Mol/blob/main/fs_mol/preprocessing/utils/
+ ## standardizer.py)
+ ## They ensure that a sequence of standardization operations is applied:
+ ## https://gist.github.com/ptosco/7e6b9ab9cc3e44ba0919060beaed198e
+
+ import os
+
+ from rdkit import Chem
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ TASKS = [
+     "NR-AR",
+     "NR-AR-LBD",
+     "NR-AhR",
+     "NR-Aromatase",
+     "NR-ER",
+     "NR-ER-LBD",
+     "NR-PPAR-gamma",
+     "SR-ARE",
+     "SR-ATAD5",
+     "SR-HSE",
+     "SR-MMP",
+     "SR-p53",
+ ]
+
+ KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
+
+ # indices of the 200 RDKit descriptors that are used (0-16 and 25-207);
+ # written as ranges instead of one literal per line for readability
+ USED_200_DESCR = list(range(0, 17)) + list(range(25, 208))
+
+
+ class Standardizer:
+     """
+     Simple wrapper class around the rdkit Standardizer.
+     """
+
+     DEFAULT_CANON_TAUT = False
+     DEFAULT_METAL_DISCONNECT = False
+     MAX_TAUTOMERS = 100
+     MAX_TRANSFORMS = 100
+     MAX_RESTARTS = 200
+     PREFER_ORGANIC = True
+
+     def __init__(
+         self,
+         metal_disconnect=None,
+         canon_taut=None,
+     ):
+         """
+         Constructor.
+         All parameters are optional.
+         :param metal_disconnect: if True, metallorganic complexes are
+                                  disconnected
+         :param canon_taut: if True, molecules are converted to their
+                            canonical tautomer
+         """
+         super().__init__()
+         if metal_disconnect is None:
+             metal_disconnect = self.DEFAULT_METAL_DISCONNECT
+         if canon_taut is None:
+             canon_taut = self.DEFAULT_CANON_TAUT
+         self._canon_taut = canon_taut
+         self._metal_disconnect = metal_disconnect
+         self._taut_enumerator = None
+         self._uncharger = None
+         self._lfrag_chooser = None
+         self._metal_disconnector = None
+         self._normalizer = None
+         self._reionizer = None
+         self._params = None
+
+     @property
+     def params(self):
+         """Return the MolStandardize CleanupParameters."""
+         if self._params is None:
+             self._params = rdMolStandardize.CleanupParameters()
+             self._params.maxTautomers = self.MAX_TAUTOMERS
+             self._params.maxTransforms = self.MAX_TRANSFORMS
+             self._params.maxRestarts = self.MAX_RESTARTS
+             self._params.preferOrganic = self.PREFER_ORGANIC
+             self._params.tautomerRemoveSp3Stereo = False
+         return self._params
+
+     @property
+     def canon_taut(self):
+         """Return whether tautomer canonicalization will be done."""
+         return self._canon_taut
+
+     @property
+     def metal_disconnect(self):
+         """Return whether metallorganic complexes will be disconnected."""
+         return self._metal_disconnect
+
+     @property
+     def taut_enumerator(self):
+         """Return the TautomerEnumerator object."""
+         if self._taut_enumerator is None:
+             self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
+         return self._taut_enumerator
+
+     @property
+     def uncharger(self):
+         """Return the Uncharger object."""
+         if self._uncharger is None:
+             self._uncharger = rdMolStandardize.Uncharger()
+         return self._uncharger
+
+     @property
+     def lfrag_chooser(self):
+         """Return the LargestFragmentChooser object."""
+         if self._lfrag_chooser is None:
+             self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(
+                 self.params.preferOrganic
+             )
+         return self._lfrag_chooser
+
+     @property
+     def metal_disconnector(self):
+         """Return the MetalDisconnector object."""
+         if self._metal_disconnector is None:
+             self._metal_disconnector = rdMolStandardize.MetalDisconnector()
+         return self._metal_disconnector
+
+     @property
+     def normalizer(self):
+         """Return the Normalizer object."""
+         if self._normalizer is None:
+             self._normalizer = rdMolStandardize.Normalizer(
+                 self.params.normalizationsFile, self.params.maxRestarts
+             )
+         return self._normalizer
+
+     @property
+     def reionizer(self):
+         """Return the Reionizer object."""
+         if self._reionizer is None:
+             self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
+         return self._reionizer
+
+     def charge_parent(self, mol_in):
+         """Sequentially apply a series of MolStandardize operations:
+         * MetalDisconnector
+         * Normalizer
+         * Reionizer
+         * LargestFragmentChooser
+         * Uncharger
+         The net result is that a desalted, normalized, neutral
+         molecule with implicit Hs is returned.
+         """
+         params = Chem.RemoveHsParameters()
+         params.removeAndTrackIsotopes = True
+         mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
+         if self._metal_disconnect:
+             mol_in = self.metal_disconnector.Disconnect(mol_in)
+         normalized = self.normalizer.normalize(mol_in)
+         Chem.SanitizeMol(normalized)
+         normalized = self.reionizer.reionize(normalized)
+         Chem.AssignStereochemistry(normalized)
+         normalized = self.lfrag_chooser.choose(normalized)
+         normalized = self.uncharger.uncharge(normalized)
+         # need this to reassess aromaticity on things like
+         # cyclopentadienyl, tropylium, azolium, etc.
+         Chem.SanitizeMol(normalized)
+         return Chem.RemoveHs(Chem.AddHs(normalized))
+
+     def standardize_mol(self, mol_in):
+         """
+         Standardize a single molecule.
+         :param mol_in: a Chem.Mol
+         :return: * (standardized Chem.Mol, n_taut) tuple
+                    if success. n_taut will be negative if
+                    tautomer enumeration was aborted due
+                    to reaching a limit
+                  * (None, error_msg) if failure
+         This calls self.charge_parent() and, if self._canon_taut
+         is True, runs tautomer canonicalization.
+         """
+         n_tautomers = 0
+         if isinstance(mol_in, Chem.Mol):
+             name = None
+             try:
+                 name = mol_in.GetProp("_Name")
+             except KeyError:
+                 pass
+             if not name:
+                 name = "NONAME"
+         else:
+             error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
+             return None, error
+         try:
+             mol_out = self.charge_parent(mol_in)
+         except Exception as e:
+             error = f"charge_parent FAILED: {str(e).strip()}"
+             return None, error
+         if self._canon_taut:
+             try:
+                 res = self.taut_enumerator.Enumerate(mol_out, False)
+             except TypeError:
+                 # we are still on the pre-2021 RDKit API
+                 res = self.taut_enumerator.Enumerate(mol_out)
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+             n_tautomers = len(res)
+             if hasattr(res, "status"):
+                 completed = (
+                     res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
+                 )
+             else:
+                 # we are still on the pre-2021 RDKit API
+                 completed = len(res) < 1000
+             if not completed:
+                 n_tautomers = -n_tautomers
+             try:
+                 mol_out = self.taut_enumerator.PickCanonical(res)
+             except AttributeError:
+                 # we are still on the pre-2021 RDKit API
+                 mol_out = max(
+                     [(self.taut_enumerator.ScoreTautomer(m), m) for m in res]
+                 )[1]
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+         mol_out.SetProp("_Name", name)
+         return mol_out, n_tautomers
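A short usage sketch for the `Standardizer` (illustrative only; it mirrors how `create_cleaned_mol_objects` in `src/preprocess.py` uses the class):

```python
from rdkit import Chem
from src.utils import Standardizer

sm = Standardizer(canon_taut=True)

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)[O-]")  # aspirin, deprotonated
std_mol, n_taut = sm.standardize_mol(mol)
if std_mol is not None:
    # desalted, neutralized, canonical-tautomer form
    print(Chem.MolToSmiles(std_mol))
else:
    # on failure, the second element is the error message
    print(f"standardization failed: {n_taut}")
```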