antoniaebner committed
Commit b0daa87 · 1 Parent(s): 06a41f1
.gitignore DELETED
@@ -1 +0,0 @@
1
- __pycache__/
 
 
README.md DELETED
@@ -1,103 +0,0 @@
1
- ---
2
- title: Tox21 XGBoost Classifier
3
- emoji: 🚀
4
- colorFrom: green
5
- colorTo: purple
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- short_description: XGBoost baseline classifier for Tox21
10
- ---
11
-
12
- # Tox21 XGBoost Classifier
13
-
14
- This repository hosts a Hugging Face Space that provides an example API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/tschouis/tox21_leaderboard).
15
-
16
- In this example, we train an XGBoost classifier on the Tox21 targets and save the trained model in the `assets/` folder.
17
-
18
- **Important:** For leaderboard submission, your Space does not need to include training code. It only needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it should take a list of SMILES strings as input and return a prediction dictionary as output, with SMILES and targets as keys. Therefore, any preprocessing of SMILES strings must be executed on-the-fly during inference.
19
-
20
- # Repository Structure
21
- - `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
22
- - `app.py` - FastAPI application wrapper (can be used as-is).
23
-
24
- - `src/` - Core model & preprocessing logic:
25
- `data.py` - Dataset loading for training (reads the preprocessed descriptor files)
- `preprocess.py` - SMILES preprocessing pipeline (cleaning, descriptors)
26
- - `model.py` - XGBoost classifier wrapper
27
- - `train.py` - Script to train the classifier
28
- `utils.py` - Constants and helper functions
29
-
30
- # Quickstart with Spaces
31
-
32
- You can easily adapt this project to your own Hugging Face account:
33
-
34
- - Open this Space on Hugging Face.
35
-
36
- - Click "Duplicate this Space" (top-right corner).
37
-
38
- Modify `src/` for your preprocessing pipeline and model class.
39
-
40
- - Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.
41
-
42
- That’s it, your model will be available as an API endpoint for the Tox21 Leaderboard.
43
-
44
- # Installation
45
- To run (and train) the XGBoost classifier, clone the repository and install the dependencies:
46
-
47
- ```bash
48
- git clone https://huggingface.co/spaces/tschouis/tox21_xgboost_classifier
49
- cd tox21_xgboost_classifier
50
-
51
- conda create -n tox21_xgb_cls python=3.11
52
- conda activate tox21_xgb_cls
53
- pip install -r requirements.txt
54
- ```
55
-
56
- # Training
57
-
58
- To train the XGBoost model from scratch:
59
-
60
- ```bash
61
- # build the preprocessed descriptor files first, then train
- python -m src.preprocess
- python -m src.train
62
- ```
63
-
64
- This will:
65
-
66
- 1. Load and preprocess the Tox21 training dataset.
67
- 2. Train an XGBoost classifier.
68
- 3. Save the trained model to the `assets/` folder.
69
- 4. Evaluate the trained XGBoost classifier on the validation split.
70
-
71
-
72
- # Inference
73
-
74
- For inference, you only need `predict.py` (together with the `src/` modules and the saved files in `assets/`).
75
-
76
- Example usage inside Python:
77
-
78
- ```python
79
- from predict import predict
80
-
81
- smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
82
- results = predict(smiles_list)
83
-
84
- print(results)
85
- ```
86
-
87
- The output is a nested dictionary with the 12 Tox21 target names as inner keys and the predicted probabilities as values, for example:
88
-
89
- ```python
90
- {
91
- "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
92
- "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
93
- "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
94
- }
95
- ```
96
-
97
- # Notes
98
-
99
- For leaderboard submission, you only need to adapt `predict.py` so that it runs your model's inference.
100
-
101
- - Training (`src/train.py`) is provided for reproducibility.
102
-
103
- Preprocessing (here inside `src/preprocess.py`) must be applied at inference time, not only during training.
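The README above requires the `predict()` skeleton to stay unchanged. Below is a minimal sketch of that skeleton (illustrative only: the constant 0.0 scores are placeholders, and the target names are the ones listed in `app.py` further down):

```python
# Minimal predict() skeleton expected by the leaderboard (placeholder scores only).
TARGETS = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]


def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    """Return a {smiles: {target: score}} dictionary covering all 12 Tox21 targets."""
    # Replace the constant 0.0 with your model's predicted probability per target.
    return {smiles: {target: 0.0 for target in TARGETS} for smiles in smiles_list}
```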
 
 
app.py DELETED
@@ -1,78 +0,0 @@
1
- """
2
- This is the main entry point for the FastAPI application.
3
- The app handles requests to predict toxicity for a list of SMILES strings.
4
- """
5
-
6
- # ---------------------------------------------------------------------------------------
7
- # Dependencies and global variable definition
8
- import os
9
- from typing import List, Dict, Optional
10
- from fastapi import FastAPI, Header, HTTPException
11
- from pydantic import BaseModel, Field
12
-
13
- from predict import predict as predict_func
14
-
15
- API_KEY = os.getenv("API_KEY") # set via Space Secrets
16
-
17
-
18
- # ---------------------------------------------------------------------------------------
19
- class Request(BaseModel):
20
- smiles: List[str] = Field(min_items=1, max_items=1000)
21
-
22
-
23
- class Response(BaseModel):
24
- predictions: dict
25
- model_info: Dict[str, str] = {}
26
-
27
-
28
- app = FastAPI(title="toxicity-api")
29
-
30
-
31
- @app.get("/")
32
- def root():
33
- return {
34
- "message": "Toxicity Prediction API",
35
- "endpoints": {
36
- "/metadata": "GET - API metadata and capabilities",
37
- "/healthz": "GET - Health check",
38
- "/predict": "POST - Predict toxicity for SMILES",
39
- },
40
- "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
41
- }
42
-
43
-
44
- @app.get("/metadata")
45
- def metadata():
46
- return {
47
- "name": "AwesomeTox",
48
- "version": "1.0.0",
49
- "max_batch_size": 256,
50
- "tox_endpoints": [
51
- "NR-AR",
52
- "NR-AR-LBD",
53
- "NR-AhR",
54
- "NR-Aromatase",
55
- "NR-ER",
56
- "NR-ER-LBD",
57
- "NR-PPAR-gamma",
58
- "SR-ARE",
59
- "SR-ATAD5",
60
- "SR-HSE",
61
- "SR-MMP",
62
- "SR-p53",
63
- ],
64
- }
65
-
66
-
67
- @app.get("/healthz")
68
- def healthz():
69
- return {"ok": True}
70
-
71
-
72
- @app.post("/predict", response_model=Response)
73
- def predict(request: Request):
74
- predictions = predict_func(request.smiles)
75
- return {
76
- "predictions": predictions,
77
- "model_info": {"name": "random_clf", "version": "1.0.0"},
78
- }
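A hypothetical client call against the `/predict` endpoint defined above (the Space URL and the bearer-token scheme are assumptions; as written, `app.py` reads `API_KEY` but does not verify the header itself):

```python
import requests

SPACE_URL = "https://<your-space>.hf.space"  # placeholder, not a real endpoint

response = requests.post(
    f"{SPACE_URL}/predict",
    json={"smiles": ["CCO", "c1ccccc1"]},
    # Auth scheme assumed from the usage hint returned by root().
    headers={"Authorization": "Bearer <API_KEY>"},
    timeout=60,
)
response.raise_for_status()
print(response.json()["predictions"])
```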
 
 
assets/ecdfs.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9d1f0b5753af1e5aa697bd3e0fc4155d6a96bdfd083139f96ce140cb3d47f127
3
- size 37660397
 
 
 
 
assets/scaler.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4b11f9ff54fd099e05a8423dcb2f3cf059c6bdfca9d068de46bf0ce0f727e136
3
- size 78256
 
 
 
 
assets/tox_smarts.json DELETED
The diff for this file is too large to render. See raw diff
 
assets/xgb_alltasks.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c987ccf417df7c3e458512ffb71d3c052efd5f091f426299644544a8971b2bb6
3
- size 34793787
 
 
 
 
assets_old/ecdfs.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9d1f0b5753af1e5aa697bd3e0fc4155d6a96bdfd083139f96ce140cb3d47f127
3
- size 37660397
 
 
 
 
assets_old/scaler.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cec4650acf6ecc7dd7b820459acc6b6a1bc1f78852ee2328798d6754465c95d0
3
- size 54415
 
 
 
 
assets_old/xgb_alltasks.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:22748188c9bffbdd15febc4caf2daf9d00d660670025fc5b4371aaf36a0e8fea
3
- size 19718840
 
 
 
 
data/ecdfs.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:782174f4353d7342d9d74ad672f00aeeeef57f35a4dad4b3e20b08a35079adf7
3
- size 33743597
 
 
 
 
predict.py DELETED
@@ -1,64 +0,0 @@
1
- """
2
- This file includes the predict function for the Tox21 leaderboard submission.
3
- As an input it takes a list of SMILES and it outputs a nested dictionary with
4
- SMILES and target names as keys.
5
- """
6
-
7
- # ---------------------------------------------------------------------------------------
8
- # Dependencies
9
- from collections import defaultdict
10
-
11
- import numpy as np
12
-
13
- from src.model import Tox21XGBClassifier
14
- from src.preprocess import create_descriptors
15
- from src.utils import load_pickle, KNOWN_DESCR
16
-
17
- # ---------------------------------------------------------------------------------------
18
-
19
-
20
- def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
21
- """Applies the classifier to a list of SMILES strings. Returns prediction=0.0 for
22
- any molecule that could not be cleaned.
23
-
24
- Args:
25
- smiles_list (list[str]): list of SMILES strings
26
-
27
- Returns:
28
- dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
29
- """
30
- print(f"Received {len(smiles_list)} SMILES strings")
31
- # preprocessing pipeline
32
- ecdfs_path = "assets/ecdfs.pkl"
33
- scaler_path = "assets/scaler.pkl"
34
- ecdfs = load_pickle(ecdfs_path)
35
- scaler = load_pickle(scaler_path)
36
- print(f"Loaded ecdfs from {ecdfs_path}")
37
- print(f"Loaded scaler from {scaler_path}")
38
-
39
- descriptors = KNOWN_DESCR
40
- features, mol_mask = create_descriptors(
41
- smiles_list,
42
- ecdfs=ecdfs,
43
- scaler=scaler,
44
- descriptors=descriptors,
45
- )
46
- print(f"Created descriptors {descriptors} for molecules.")
47
- print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning")
48
-
49
- # setup model
50
- model = Tox21XGBClassifier(seed=42)
51
- model_path = "assets/xgb_alltasks.joblib"
52
- model.load_model(model_path)
53
- print(f"Loaded model from {model_path}")
54
-
55
- # make predictions
56
- predictions = defaultdict(dict)
57
- # build an index array with the same length as smiles_list to look up each molecule's feature row
58
- feat_indices = np.cumsum(mol_mask) - 1
59
-
60
- for target in model.tasks:
61
- target_pred = model.predict(target, features)
62
- for smiles, is_clean, i in zip(smiles_list, mol_mask, feat_indices):
63
- predictions[smiles][target] = float(target_pred[i]) if is_clean else 0.0
64
- return predictions
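`predict()` above uses `np.cumsum(mol_mask) - 1` to turn the cleaning mask into per-molecule feature-row indices. A standalone toy illustration of that indexing pattern (numbers invented for the example):

```python
import numpy as np

# Toy mask: molecules 0, 2 and 3 survived cleaning, molecule 1 was dropped.
mol_mask = np.array([True, False, True, True])
feat_indices = np.cumsum(mol_mask) - 1  # -> array([0, 0, 1, 2])

# A surviving molecule at original position i is looked up at row feat_indices[i];
# positions where mol_mask is False are skipped (predict() returns 0.0 for them).
for i, (is_clean, row) in enumerate(zip(mol_mask, feat_indices)):
    print(i, int(row) if is_clean else None)
```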
 
 
requirements.txt DELETED
@@ -1,10 +0,0 @@
1
- fastapi
2
- uvicorn[standard]
3
- statsmodels
4
- rdkit
5
- numpy
6
- scikit-learn==1.7.1
7
- joblib
8
- tabulate
9
- datasets
10
- xgboost==3.0.5
 
 
src/__init__.py DELETED
File without changes
src/data.py DELETED
@@ -1,90 +0,0 @@
1
- # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
2
-
3
- """
4
- This file includes the dataset loading for Tox21 training.
5
- It reads preprocessed descriptor files (.npz) and returns feature matrices, labels,
6
- and the fitted scaler.
7
- """
8
-
9
- from typing import Iterable, Literal
10
-
11
- import numpy as np
12
- import torch
13
-
14
- from .preprocess import normalize_features
15
-
16
- KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
17
-
18
-
19
- def get_descriptor_dataset(
20
- data_path: str,
21
- descriptors: Iterable[str] | Literal["all"],
22
- scaler=None,
23
- save_scaler_path: str = "data/scaler.pkl",
24
- verbose=True,
25
- normalize=True,
26
- ):
27
- if descriptors == "all":
28
- descriptors = KNOWN_DESCR
29
-
30
- assert isinstance(descriptors, Iterable), "Passed descriptors are not iterable!"
31
- assert all(
32
- [descr in KNOWN_DESCR for descr in descriptors]
33
- ), f"Passed descriptors contains unknown descriptor types. Allowed descriptors: {KNOWN_DESCR}"
34
-
35
- datafile = np.load(data_path)
36
-
37
- if not isinstance(datafile, np.ndarray):
38
- # concatenate all descriptors and normalize
39
- data = np.concatenate([datafile[descr] for descr in descriptors], axis=1)
40
- labels = datafile["labels"]
41
-
42
- else:
43
- print("NPY file passed, cannot select specific descriptors")
44
- data, labels = datafile[:, :-12], datafile[:, -12:]
45
-
46
- if normalize:
47
- data, scaler = normalize_features(
48
- data,
49
- scaler=scaler,
50
- save_scaler_path=save_scaler_path,
51
- verbose=verbose,
52
- )
53
-
54
- # filter out unsanitized molecules
55
- mask = ~np.isnan(data).any(axis=1)
56
- data = data[mask]
57
- labels = labels[mask]
58
-
59
- assert data.shape[0] == labels.shape[0], (
60
- f"Mismatch between data and labels: "
61
- f"data has {data.shape[0]} samples, but labels has {labels.shape[0]} samples."
62
- )
63
-
64
- return (data, labels, scaler)
65
-
66
-
67
- def get_torch_descriptor_dataset(
68
- data_path: str,
69
- descriptors: list[str],
70
- scaler=None,
71
- save_scaler_path: str = "data/scaler.pkl",
72
- nan_to_num: int = -100,
73
- verbose=True,
74
- normalize=True,
75
- ) -> torch.utils.data.TensorDataset:
76
- data, labels, scaler = get_descriptor_dataset(
77
- data_path,
78
- descriptors,
79
- scaler,
80
- save_scaler_path,
81
- verbose=verbose,
82
- normalize=normalize,
83
- )
84
-
85
- labels = np.nan_to_num(labels, nan=nan_to_num)
86
-
87
- dataset = torch.utils.data.TensorDataset(
88
- torch.FloatTensor(data), torch.LongTensor(labels)
89
- )
90
- return dataset, scaler
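A usage sketch for `get_descriptor_dataset` (the `.npz` paths are assumptions; such files are produced by `src/preprocess.py`):

```python
from src.data import get_descriptor_dataset

# Fit the scaler on the training split, then reuse it for validation.
train_X, train_y, scaler = get_descriptor_dataset(
    "data/tox21_train.npz",              # assumed output of src/preprocess.py
    descriptors="all",                   # ecfps + rdkit quantiles + maccs + tox
    save_scaler_path="data/scaler.pkl",
)
val_X, val_y, _ = get_descriptor_dataset(
    "data/tox21_validation.npz",
    descriptors="all",
    scaler=scaler,                       # reuse the fitted scaler
)
print(train_X.shape, train_y.shape, val_X.shape)
```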
 
 
src/model.py DELETED
@@ -1,79 +0,0 @@
1
- """
2
- This file includes an XGBoost model for Tox21.
3
- It wraps one XGBClassifier per Tox21 task and predicts toxicity probabilities
4
- from precomputed molecule features.
5
- """
6
-
7
- # ---------------------------------------------------------------------------------------
8
- # Dependencies
9
- import os
10
- import joblib
11
-
12
- import numpy as np
13
- from xgboost import XGBClassifier
14
-
15
- from .utils import TASKS
16
-
17
-
18
- # ---------------------------------------------------------------------------------------
19
- class Tox21XGBClassifier:
20
- """A XGBoost classifier that assigns a toxicity score to a given SMILES string."""
21
-
22
- def __init__(self, seed: int = 42, task_config: dict | None = None):
23
- """Initialize an XGBoost classifier for each of the 12 Tox21 tasks.
24
-
25
- Args:
26
- seed (int, optional): seed for XGBoost to ensure reproducibility. Defaults to 42.
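- task_config (dict, optional): per-task XGBoost hyperparameters keyed by task name. Defaults to None.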
27
- """
28
- self.tasks = TASKS
29
- self.model = {
30
- task: XGBClassifier(**(task_config or {}).get(task, {"n_estimators": 1000}), random_state=seed, n_jobs=8)
31
- for task in self.tasks
32
- }
33
-
34
- def load_model(self, path: str) -> None:
35
- """Loads the model from a given path
36
-
37
- Args:
38
- path (str): path to model checkpoint
39
- """
40
- self.model = joblib.load(path)
41
-
42
- def save_model(self, path: str) -> None:
43
- """Saves the model to a given path
44
-
45
- Args:
46
- path (str): path to save model to
47
- """
48
- if not os.path.exists(os.path.dirname(path)):
49
- os.makedirs(os.path.dirname(path))
50
-
51
- joblib.dump(self.model, path)
52
-
53
- def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
54
- """Train XGBoost for a given task
55
-
56
- Args:
57
- task (str): task to train
58
- input_features (np.ndarray): training features
59
- labels (np.ndarray): training labels
60
- """
61
- assert task in self.tasks, f"Unknown task: {task}"
62
- self.model[task].fit(input_features, labels)
63
-
64
- def predict(self, task: str, features: np.ndarray) -> np.ndarray:
65
- """Predicts labels for a given Tox21 target using molecule features
66
-
67
- Args:
68
- task (str): the Tox21 target to predict for
69
- features (np.ndarray): molecule features used for prediction
70
-
71
- Returns:
72
- np.ndarray: predicted probability for positive class
73
- """
74
- assert task in self.tasks, f"Unknown task: {task}"
75
- assert (
76
- len(features.shape) == 2
77
- ), f"Function expects 2D np.array. Current shape: {features.shape}"
78
- preds = self.model[task].predict_proba(features)
79
- return preds[:, 1]
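A short usage sketch for `Tox21XGBClassifier` (toy random features and labels, and a hypothetical save path, for illustration only):

```python
import numpy as np
from src.model import Tox21XGBClassifier

rng = np.random.default_rng(0)
X = rng.random((64, 16))                 # toy feature matrix
y = rng.integers(0, 2, size=64)          # toy binary labels

model = Tox21XGBClassifier(seed=42)
model.fit("NR-AR", X, y)                 # one XGBClassifier per Tox21 task
probs = model.predict("NR-AR", X)        # probability of the positive (toxic) class

model.save_model("assets_demo/xgb_demo.joblib")   # hypothetical path
restored = Tox21XGBClassifier(seed=42)
restored.load_model("assets_demo/xgb_demo.joblib")
```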
 
 
src/preprocess.py DELETED
@@ -1,405 +0,0 @@
1
- # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
2
-
3
- """
4
- This file includes the data preprocessing for Tox21.
5
- As an input it takes a list of SMILES and it outputs the concatenated descriptor
6
- features together with a mask marking molecules that could not be cleaned.
7
- """
8
-
9
- import os
10
- import argparse
11
- import json
12
- from typing import Iterable
13
-
14
- import numpy as np
15
- import pandas as pd
16
-
17
- from sklearn.preprocessing import StandardScaler
18
- from statsmodels.distributions.empirical_distribution import ECDF
19
- from datasets import load_dataset
20
-
21
- from rdkit import Chem, DataStructs
22
- from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
23
- from rdkit.Chem.rdchem import Mol
24
-
25
- from src.utils import (
26
- TASKS,
27
- KNOWN_DESCR,
28
- HF_TOKEN,
29
- USED_200_DESCR,
30
- Standardizer,
31
- load_pickle,
32
- write_pickle,
33
- )
34
-
35
- parser = argparse.ArgumentParser(
36
- description="Data preprocessing script for the Tox21 dataset"
37
- )
38
-
39
- parser.add_argument(
40
- "--save_folder",
41
- type=str,
42
- default="data/",
43
- )
44
-
45
- parser.add_argument(
46
- "--use_hf",
47
- type=int,
48
- default=0,
49
- )
50
-
51
- parser.add_argument(
52
- "--path_ecdfs",
53
- type=str,
54
- default="data/ecdfs.pkl",
55
- )
56
-
57
- parser.add_argument(
58
- "--tox_smarts_filepath",
59
- type=str,
60
- default="data/tox_smarts.json",
61
- )
62
-
63
-
64
- def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
65
- """This function creates cleaned RDKit mol objects from a list of SMILES.
66
-
67
- Args:
68
- smiles (list[str]): list of SMILES
69
-
70
- Returns:
71
- list[Mol]: list of cleaned molecules
72
- np.ndarray[bool]: mask that contains False at index `i`, if molecule in `smiles` at
73
- index `i` could not be cleaned and was removed.
74
- """
75
- sm = Standardizer(canon_taut=True)
76
-
77
- clean_mol_mask = list()
78
- mols = list()
79
- for i, smile in enumerate(smiles):
80
- mol = Chem.MolFromSmiles(smile)
81
- standardized_mol, _ = sm.standardize_mol(mol)
82
- is_cleaned = standardized_mol is not None
83
- clean_mol_mask.append(is_cleaned)
84
- if not is_cleaned:
85
- continue
86
- can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
87
- mols.append(can_mol)
88
-
89
- return mols, np.array(clean_mol_mask)
90
-
91
-
92
- def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
93
- """This function ECFP fingerprints for a list of molecules.
94
-
95
- Args:
96
- mols (list[Mol]): list of molecules
97
-
98
- Returns:
99
- np.ndarray: ECFP fingerprints of molecules
100
- """
101
- ecfps = list()
102
-
103
- for mol in mols:
104
- fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
105
- [mol], fpType=rdFingerprintGenerator.MorganFP
106
- )[0]
107
- fp = np.zeros((0,), np.int8)
108
- DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
109
-
110
- ecfps.append(fp)
111
-
112
- return np.array(ecfps)
113
-
114
-
115
- def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
116
- maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
117
- return np.array(maccs)
118
-
119
-
120
- def get_tox_patterns(filepath: str):
121
- """This calculates tox features defined in tox_smarts.json.
122
- Args:
123
- filepath: path to the tox_smarts.json file containing the SMARTS patterns
124
- Returns: a list of (patterns, negations, merge_any) tuples ready for matching
125
- """
126
- # load patterns
127
- with open(filepath) as f:
128
- smarts_list = [s[1] for s in json.load(f)]
129
-
130
- # Code does not work for this case
131
- assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
132
-
133
- # Chem.MolFromSmarts takes a long time so it pays off to parse all the smarts first
134
- # and then use them for all molecules. This gives a huge speedup over existing code.
135
- # a list of patterns, whether to negate the match result and how to join them to obtain one boolean value
136
- all_patterns = []
137
- for smarts in smarts_list:
138
- patterns = [] # list of smarts-patterns
139
- # value for each of the patterns above. Negates the values of the above later.
140
- negations = []
141
-
142
- if " AND " in smarts:
143
- smarts = smarts.split(" AND ")
144
- merge_any = False # If an ' AND ' is found all 'subsmarts' have to match
145
- else:
146
- # If there is an ' OR ' present it's enough is any of the 'subsmarts' match.
147
- # This also accumulates smarts where neither ' OR ' nor ' AND ' occur
148
- smarts = smarts.split(" OR ")
149
- merge_any = True
150
-
151
- # for all subsmarts check if they are preceded by 'NOT '
152
- for s in smarts:
153
- neg = s.startswith("NOT ")
154
- if neg:
155
- s = s[4:]
156
- patterns.append(Chem.MolFromSmarts(s))
157
- negations.append(neg)
158
-
159
- all_patterns.append((patterns, negations, merge_any))
160
- return all_patterns
161
-
162
-
163
- def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
164
- """Matches the tox patterns against a molecule. Returns a boolean array"""
165
- tox_data = []
166
- for mol in mols:
167
- mol_features = []
168
- for patts, negations, merge_any in patterns:
169
- matches = [mol.HasSubstructMatch(p) for p in patts]
170
- matches = [m != n for m, n in zip(matches, negations)]
171
- if merge_any:
172
- pres = any(matches)
173
- else:
174
- pres = all(matches)
175
- mol_features.append(pres)
176
-
177
- tox_data.append(np.array(mol_features))
178
-
179
- return np.array(tox_data)
180
-
181
-
182
- def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
183
- """This function creates RDKit descriptors for a list of molecules.
184
-
185
- Args:
186
- mols (list[Mol]): list of molecules
187
-
188
- Returns:
189
- np.ndarray: RDKit descriptors of molecules
190
- """
191
- rdkit_descriptors = list()
192
-
193
- for mol in mols:
194
- descrs = []
195
- for _, descr_calc_fn in Descriptors._descList:
196
- descrs.append(descr_calc_fn(mol))
197
-
198
- descrs = np.array(descrs)
199
- descrs = descrs[USED_200_DESCR]
200
- rdkit_descriptors.append(descrs)
201
-
202
- return np.array(rdkit_descriptors)
203
-
204
-
205
- def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
206
- """Create quantile values for given features using the columns
207
-
208
- Args:
209
- raw_features (np.ndarray): values to put into quantiles
210
- ecdfs (list): ECDFs to use
211
-
212
- Returns:
213
- np.ndarray: computed quantiles
214
- """
215
- quantiles = np.zeros_like(raw_features)
216
-
217
- for column in range(raw_features.shape[1]):
218
- raw_values = raw_features[:, column].reshape(-1)
219
- ecdf = ecdfs[column]
220
- q = ecdf(raw_values)
221
- quantiles[:, column] = q
222
-
223
- return quantiles
224
-
225
-
226
- def fill(features, mask, value=np.nan):
227
- n_mols = len(mask)
228
- n_features = features.shape[1]
229
-
230
- data = np.zeros(shape=(n_mols, n_features))
231
- data.fill(value)
232
- data[~mask] = features
233
- return data
234
-
235
-
236
- def normalize_features(
237
- raw_features,
238
- scaler=None,
239
- save_scaler_path: str = "",
240
- verbose=True,
241
- ):
242
- if scaler is None:
243
- scaler = StandardScaler()
244
- scaler.fit(raw_features)
245
- if verbose:
246
- print("Fitted the StandardScaler")
247
- if save_scaler_path:
248
- write_pickle(save_scaler_path, scaler)
249
- if verbose:
250
- print(f"Saved the StandardScaler under {save_scaler_path}")
251
-
252
- # Normalize feature vectors
253
- normalized_features = scaler.transform(raw_features)
254
- if verbose:
255
- print("Normalized molecule features")
256
- return normalized_features, scaler
257
-
258
-
259
- def create_descriptors(
260
- smiles,
261
- ecdfs=None,
262
- scaler=None,
263
- descriptors: Iterable = KNOWN_DESCR,
264
- ):
265
- # Create cleaned rdkit mol objects
266
- mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
267
- print("Cleaned molecules")
268
-
269
- features = []
270
- if "ecfps" in descriptors:
271
- # Create fingerprints and descriptors
272
- ecfps = create_ecfp_fps(mols)
273
- # expand using mol_mask
274
- ecfps = fill(ecfps, ~clean_mol_mask)
275
- features.append(ecfps)
276
- print("Created ECFP fingerprints")
277
-
278
- if "rdkit_descr_quantiles" in descriptors:
279
- rdkit_descrs = create_rdkit_descriptors(mols)
280
- print("Created RDKit descriptors")
281
-
282
- # Create and save ecdfs
283
- if ecdfs is None:
284
- print("Create ECDFs")
285
- ecdfs = []
286
- for column in range(rdkit_descrs.shape[1]):
287
- raw_values = rdkit_descrs[:, column].reshape(-1)
288
- ecdfs.append(ECDF(raw_values))
289
-
290
- # Create quantiles
291
- rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
292
- # expand using mol_mask
293
- rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
294
- features.append(rdkit_descr_quantiles)
295
- print("Created quantiles of RDKit descriptors")
296
-
297
- if "maccs" in descriptors:
298
- maccs = create_maccs_keys(mols)
299
- maccs = fill(maccs, ~clean_mol_mask)
300
- features.append(maccs)
301
- print("Created MACCS keys")
302
-
303
- if "tox" in descriptors:
304
- tox_patterns = get_tox_patterns("assets/tox_smarts.json")
305
- tox = create_tox_features(mols, tox_patterns)
306
- tox = fill(tox, ~clean_mol_mask)
307
- features.append(tox)
308
- print("Created Tox features")
309
-
310
- # concatenate features
311
- raw_features = np.concatenate(features, axis=1)
312
-
313
- # normalize with scaler if scaler is passed, else create scaler
314
- features, _ = normalize_features(
315
- raw_features,
316
- scaler=scaler,
317
- verbose=True,
318
- )
319
-
320
- return features, clean_mol_mask
321
-
322
-
323
- def main(args):
324
- splits = ["train", "validation"]
325
- ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
326
-
327
- for split in splits:
328
-
329
- print(f"Preprocess {split} molecules")
330
- smiles = list(ds[split]["smiles"])
331
-
332
- # Create cleaned rdkit mol objects
333
- mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
334
- print("Cleaned molecules")
335
-
336
- tox_patterns = get_tox_patterns(args.tox_smarts_filepath)
337
-
338
- # Create fingerprints and descriptors
339
- ecfps = create_ecfp_fps(mols)
340
- # expand using mol_mask
341
- ecfps = fill(ecfps, ~clean_mol_mask)
342
- print("Created ECFP fingerprints")
343
-
344
- rdkit_descrs = create_rdkit_descriptors(mols)
345
- print("Created RDKit descriptors")
346
-
347
- # Create and save ecdfs
348
- if split == "train":
349
- print("Create ECDFs")
350
- ecdfs = []
351
- for column in range(rdkit_descrs.shape[1]):
352
- raw_values = rdkit_descrs[:, column].reshape(-1)
353
- ecdfs.append(ECDF(raw_values))
354
-
355
- write_pickle(args.path_ecdfs, ecdfs)
356
- print(f"Saved ECDFs under {args.path_ecdfs}")
357
- else:
358
- print(f"Load ECDFs from {args.path_ecdfs}")
359
- ecdfs = load_pickle(args.path_ecdfs)
360
-
361
- # Create quantiles
362
- rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
363
- # expand using mol_mask
364
- rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
365
- print("Created quantiles of RDKit descriptors")
366
-
367
- maccs = create_maccs_keys(mols)
368
- maccs = fill(maccs, ~clean_mol_mask)
369
- print("Created MACCS keys")
370
-
371
- tox = create_tox_features(mols, tox_patterns)
372
- tox = fill(tox, ~clean_mol_mask)
373
- print("Created Tox features")
374
-
375
- labels = []
376
- for task in TASKS:
377
- datasplit = ds[split].to_pandas() if args.use_hf else ds[split]
378
- labels.append(datasplit[task].to_numpy())
379
- labels = np.stack(labels, axis=1)
380
-
381
- save_path = os.path.join(args.save_folder, f"tox21_{split}.npz")
382
- with open(save_path, "wb") as f:
383
- np.savez(
384
- f,
385
- labels=labels,
386
- ecfps=ecfps,
387
- rdkit_descr_quantiles=rdkit_descr_quantiles,
388
- maccs=maccs,
389
- tox=tox,
390
- )
391
- print(f"Saved preprocessed {split} split under {save_path}")
392
-
393
- print("Preprocessing finished successfully")
394
-
395
-
396
- if __name__ == "__main__":
397
- args = parser.parse_args()
398
-
399
- if not os.path.exists(args.save_folder):
400
- os.makedirs(args.save_folder)
401
-
402
- if not os.path.exists(os.path.dirname(args.path_ecdfs)):
403
- os.makedirs(os.path.dirname(args.path_ecdfs))
404
-
405
- main(args)
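A standalone illustration of the ECDF quantile transform used above for the RDKit descriptors (toy numbers, not taken from the dataset):

```python
import numpy as np
from statsmodels.distributions.empirical_distribution import ECDF

# One descriptor column of a (toy) training split.
train_values = np.array([1.0, 2.0, 2.0, 5.0, 10.0])
ecdf = ECDF(train_values)          # fitted on the training split only

# At inference time the same ECDF maps raw descriptor values to quantiles in [0, 1].
new_values = np.array([0.5, 2.0, 7.0])
print(ecdf(new_values))            # -> [0.  0.6 0.8]
```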
 
 
src/push_assets.py DELETED
@@ -1,12 +0,0 @@
1
- from huggingface_hub import HfApi
2
- from .utils import HF_TOKEN
3
-
4
- api = HfApi()
5
-
6
- api.upload_folder(
7
- folder_path="assets/",
8
- path_in_repo="assets",
9
- repo_id="tschouis/tox21_xgboost_classifier",
10
- repo_type="space",
11
- token=HF_TOKEN,
12
- )
 
 
src/train.py DELETED
@@ -1,199 +0,0 @@
1
- """
2
- Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model
3
- """
4
-
5
- import os
6
- import argparse
7
-
8
- import numpy as np
9
-
10
- from tabulate import tabulate
11
- from sklearn.metrics import roc_auc_score
12
-
13
- from .data import get_descriptor_dataset
14
- from .model import Tox21XGBClassifier
15
-
16
- SEED = 42
17
- DATA_FOLDER = "data/"
18
-
19
- parser = argparse.ArgumentParser(description="XGBoost Trainig script for Tox21 dataset")
20
-
21
- parser.add_argument(
22
- "--save_path_model",
23
- type=str,
24
- default="assets/xgb_alltasks.joblib",
25
- )
26
-
27
- parser.add_argument(
28
- "--path_ecdfs",
29
- type=str,
30
- default="assets/ecdfs.pkl",
31
- )
32
-
33
- parser.add_argument(
34
- "--path_scaler",
35
- type=str,
36
- default="assets/scaler.pkl",
37
- )
38
-
39
-
40
- def main(args):
41
- print("Preprocess train molecules")
42
- # load datasets
43
- train_X, train_y, scaler = get_descriptor_dataset(
44
- os.path.join(DATA_FOLDER, "tox21_train.npz"),
45
- descriptors="all",
46
- save_scaler_path="data/scaler.pkl",
47
- )
48
- val_X, val_y, _ = get_descriptor_dataset(
49
- os.path.join(DATA_FOLDER, "tox21_validation.npz"),
50
- descriptors="all",
51
- scaler=scaler,
52
- )
53
-
54
- task_config = {
55
- "NR-AR": {
56
- "colsample_bytree": 0.5,
57
- "learning_rate": 0.05,
58
- "max_depth": 12,
59
- "min_child_weight": 2,
60
- "n_estimators": 1000,
61
- "scale_pos_weight": 80,
62
- "subsample": 0.4,
63
- },
64
- "NR-AR-LBD": {
65
- "colsample_bytree": 0.8,
66
- "learning_rate": 0.04,
67
- "max_depth": 10,
68
- "min_child_weight": 8,
69
- "n_estimators": 1000,
70
- "scale_pos_weight": 10,
71
- "subsample": 0.4,
72
- },
73
- "NR-AhR": {
74
- "colsample_bytree": 0.8,
75
- "learning_rate": 0.05,
76
- "max_depth": 16,
77
- "min_child_weight": 2,
78
- "n_estimators": 1000,
79
- "scale_pos_weight": 80,
80
- "subsample": 1,
81
- },
82
- "NR-Aromatase": {
83
- "colsample_bytree": 0.7,
84
- "learning_rate": 0.05,
85
- "max_depth": 16,
86
- "min_child_weight": 1,
87
- "n_estimators": 1000,
88
- "scale_pos_weight": 50,
89
- "subsample": 0.7,
90
- },
91
- "NR-ER": {
92
- "colsample_bytree": 0.7,
93
- "learning_rate": 0.05,
94
- "max_depth": 10,
95
- "min_child_weight": 4,
96
- "n_estimators": 1000,
97
- "scale_pos_weight": 25,
98
- "subsample": 0.4,
99
- },
100
- "NR-ER-LBD": {
101
- "colsample_bytree": 0.7,
102
- "learning_rate": 0.05,
103
- "max_depth": 16,
104
- "min_child_weight": 4,
105
- "n_estimators": 1000,
106
- "scale_pos_weight": 10,
107
- "subsample": 0.4,
108
- },
109
- "NR-PPAR-gamma": {
110
- "colsample_bytree": 0.8,
111
- "learning_rate": 0.01,
112
- "max_depth": 12,
113
- "min_child_weight": 2,
114
- "n_estimators": 1000,
115
- "scale_pos_weight": 80,
116
- "subsample": 0.4,
117
- },
118
- "SR-ARE": {
119
- "colsample_bytree": 0.7,
120
- "learning_rate": 0.05,
121
- "max_depth": 16,
122
- "min_child_weight": 8,
123
- "n_estimators": 1000,
124
- "scale_pos_weight": 10,
125
- "subsample": 0.7,
126
- },
127
- "SR-ATAD5": {
128
- "colsample_bytree": 0.5,
129
- "learning_rate": 0.02,
130
- "max_depth": 12,
131
- "min_child_weight": 8,
132
- "n_estimators": 1000,
133
- "scale_pos_weight": 10,
134
- "subsample": 0.4,
135
- },
136
- "SR-HSE": {
137
- "colsample_bytree": 0.8,
138
- "learning_rate": 0.02,
139
- "max_depth": 6,
140
- "min_child_weight": 1,
141
- "n_estimators": 1000,
142
- "scale_pos_weight": 25,
143
- "subsample": 1,
144
- },
145
- "SR-MMP": {
146
- "colsample_bytree": 0.5,
147
- "learning_rate": 0.02,
148
- "max_depth": 16,
149
- "min_child_weight": 2,
150
- "n_estimators": 1000,
151
- "scale_pos_weight": 10,
152
- "subsample": 0.7,
153
- },
154
- "SR-p53": {
155
- "colsample_bytree": 0.5,
156
- "learning_rate": 0.02,
157
- "max_depth": 12,
158
- "min_child_weight": 8,
159
- "n_estimators": 1000,
160
- "scale_pos_weight": 10,
161
- "subsample": 0.4,
162
- },
163
- }
164
-
165
- model = Tox21XGBClassifier(seed=42, task_config=task_config)
166
- print("Start training.")
167
- for i, task in enumerate(model.tasks):
168
- task_labels = train_y[:, i]
169
- label_mask = ~np.isnan(task_labels)
170
-
171
- task_data = train_X[label_mask]
172
- task_labels = task_labels[label_mask].astype(int)
173
-
174
- print(f"Fit task {task} using {sum(label_mask)} samples")
175
- model.fit(task, task_data, task_labels)
176
-
177
- print(f"Save model under {args.save_path_model}")
178
- model.save_model(args.save_path_model)
179
-
180
- print("Evaluate model")
181
- results = {}
182
- for i, task in enumerate(model.tasks):
183
- task_labels = val_y[:, i]
184
- label_mask = ~np.isnan(task_labels)
185
-
186
- task_data = val_X[label_mask]
187
- task_labels = task_labels[label_mask].astype(int)
188
-
189
- pred = model.predict(task, task_data)
190
- results[task] = [roc_auc_score(y_true=task_labels, y_score=pred)]
191
-
192
- print("Results:")
193
- print(tabulate(results, headers="keys"))
194
- print("Average: ", sum([val[0] for val in results.values()]) / len(results))
195
-
196
-
197
- if __name__ == "__main__":
198
- args = parser.parse_args()
199
- main(args)
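The training loop above drops molecules without a measurement for the current task before fitting. A toy illustration of that masking (label values invented for the example):

```python
import numpy as np

# rows = molecules, columns = tasks; NaN marks a missing measurement
train_y = np.array([
    [1.0, np.nan],
    [0.0, 1.0],
    [np.nan, 0.0],
])

task_labels = train_y[:, 0]                    # labels for the first task
label_mask = ~np.isnan(task_labels)            # keep only measured molecules
print(label_mask.sum(), task_labels[label_mask].astype(int))  # -> 2 [1 0]
```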
 
 
src/utils.py DELETED
@@ -1,443 +0,0 @@
1
- ## These MolStandardizer classes are due to Paolo Tosco
2
- ## It was taken from the FS-Mol github
3
- ## (https://github.com/microsoft/FS-Mol/blob/main/fs_mol/preprocessing/utils/
4
- ## standardizer.py)
5
- ## They ensure that a sequence of standardization operations are applied
6
- ## https://gist.github.com/ptosco/7e6b9ab9cc3e44ba0919060beaed198e
7
-
8
- import os
9
- import pickle
10
-
11
- from rdkit import Chem
12
- from rdkit.Chem.MolStandardize import rdMolStandardize
13
-
14
- HF_TOKEN = os.environ.get("HF_TOKEN")
15
-
16
- TASKS = [
17
- "NR-AR",
18
- "NR-AR-LBD",
19
- "NR-AhR",
20
- "NR-Aromatase",
21
- "NR-ER",
22
- "NR-ER-LBD",
23
- "NR-PPAR-gamma",
24
- "SR-ARE",
25
- "SR-ATAD5",
26
- "SR-HSE",
27
- "SR-MMP",
28
- "SR-p53",
29
- ]
30
-
31
- KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
32
-
33
- USED_200_DESCR = [
34
- 0,
35
- 1,
36
- 2,
37
- 3,
38
- 4,
39
- 5,
40
- 6,
41
- 7,
42
- 8,
43
- 9,
44
- 10,
45
- 11,
46
- 12,
47
- 13,
48
- 14,
49
- 15,
50
- 16,
51
- 25,
52
- 26,
53
- 27,
54
- 28,
55
- 29,
56
- 30,
57
- 31,
58
- 32,
59
- 33,
60
- 34,
61
- 35,
62
- 36,
63
- 37,
64
- 38,
65
- 39,
66
- 40,
67
- 41,
68
- 42,
69
- 43,
70
- 44,
71
- 45,
72
- 46,
73
- 47,
74
- 48,
75
- 49,
76
- 50,
77
- 51,
78
- 52,
79
- 53,
80
- 54,
81
- 55,
82
- 56,
83
- 57,
84
- 58,
85
- 59,
86
- 60,
87
- 61,
88
- 62,
89
- 63,
90
- 64,
91
- 65,
92
- 66,
93
- 67,
94
- 68,
95
- 69,
96
- 70,
97
- 71,
98
- 72,
99
- 73,
100
- 74,
101
- 75,
102
- 76,
103
- 77,
104
- 78,
105
- 79,
106
- 80,
107
- 81,
108
- 82,
109
- 83,
110
- 84,
111
- 85,
112
- 86,
113
- 87,
114
- 88,
115
- 89,
116
- 90,
117
- 91,
118
- 92,
119
- 93,
120
- 94,
121
- 95,
122
- 96,
123
- 97,
124
- 98,
125
- 99,
126
- 100,
127
- 101,
128
- 102,
129
- 103,
130
- 104,
131
- 105,
132
- 106,
133
- 107,
134
- 108,
135
- 109,
136
- 110,
137
- 111,
138
- 112,
139
- 113,
140
- 114,
141
- 115,
142
- 116,
143
- 117,
144
- 118,
145
- 119,
146
- 120,
147
- 121,
148
- 122,
149
- 123,
150
- 124,
151
- 125,
152
- 126,
153
- 127,
154
- 128,
155
- 129,
156
- 130,
157
- 131,
158
- 132,
159
- 133,
160
- 134,
161
- 135,
162
- 136,
163
- 137,
164
- 138,
165
- 139,
166
- 140,
167
- 141,
168
- 142,
169
- 143,
170
- 144,
171
- 145,
172
- 146,
173
- 147,
174
- 148,
175
- 149,
176
- 150,
177
- 151,
178
- 152,
179
- 153,
180
- 154,
181
- 155,
182
- 156,
183
- 157,
184
- 158,
185
- 159,
186
- 160,
187
- 161,
188
- 162,
189
- 163,
190
- 164,
191
- 165,
192
- 166,
193
- 167,
194
- 168,
195
- 169,
196
- 170,
197
- 171,
198
- 172,
199
- 173,
200
- 174,
201
- 175,
202
- 176,
203
- 177,
204
- 178,
205
- 179,
206
- 180,
207
- 181,
208
- 182,
209
- 183,
210
- 184,
211
- 185,
212
- 186,
213
- 187,
214
- 188,
215
- 189,
216
- 190,
217
- 191,
218
- 192,
219
- 193,
220
- 194,
221
- 195,
222
- 196,
223
- 197,
224
- 198,
225
- 199,
226
- 200,
227
- 201,
228
- 202,
229
- 203,
230
- 204,
231
- 205,
232
- 206,
233
- 207,
234
- ]
235
-
236
-
237
- class Standardizer:
238
- """
239
- Simple wrapper class around rdkit Standardizer.
240
- """
241
-
242
- DEFAULT_CANON_TAUT = False
243
- DEFAULT_METAL_DISCONNECT = False
244
- MAX_TAUTOMERS = 100
245
- MAX_TRANSFORMS = 100
246
- MAX_RESTARTS = 200
247
- PREFER_ORGANIC = True
248
-
249
- def __init__(
250
- self,
251
- metal_disconnect=None,
252
- canon_taut=None,
253
- ):
254
- """
255
- Constructor.
256
- All parameters are optional.
257
- :param metal_disconnect: if True, metallorganic complexes are
258
- disconnected
259
- :param canon_taut: if True, molecules are converted to their
260
- canonical tautomer
261
- """
262
- super().__init__()
263
- if metal_disconnect is None:
264
- metal_disconnect = self.DEFAULT_METAL_DISCONNECT
265
- if canon_taut is None:
266
- canon_taut = self.DEFAULT_CANON_TAUT
267
- self._canon_taut = canon_taut
268
- self._metal_disconnect = metal_disconnect
269
- self._taut_enumerator = None
270
- self._uncharger = None
271
- self._lfrag_chooser = None
272
- self._metal_disconnector = None
273
- self._normalizer = None
274
- self._reionizer = None
275
- self._params = None
276
-
277
- @property
278
- def params(self):
279
- """Return the MolStandardize CleanupParameters."""
280
- if self._params is None:
281
- self._params = rdMolStandardize.CleanupParameters()
282
- self._params.maxTautomers = self.MAX_TAUTOMERS
283
- self._params.maxTransforms = self.MAX_TRANSFORMS
284
- self._params.maxRestarts = self.MAX_RESTARTS
285
- self._params.preferOrganic = self.PREFER_ORGANIC
286
- self._params.tautomerRemoveSp3Stereo = False
287
- return self._params
288
-
289
- @property
290
- def canon_taut(self):
291
- """Return whether tautomer canonicalization will be done."""
292
- return self._canon_taut
293
-
294
- @property
295
- def metal_disconnect(self):
296
- """Return whether metallorganic complexes will be disconnected."""
297
- return self._metal_disconnect
298
-
299
- @property
300
- def taut_enumerator(self):
301
- """Return the TautomerEnumerator object."""
302
- if self._taut_enumerator is None:
303
- self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
304
- return self._taut_enumerator
305
-
306
- @property
307
- def uncharger(self):
308
- """Return the Uncharger object."""
309
- if self._uncharger is None:
310
- self._uncharger = rdMolStandardize.Uncharger()
311
- return self._uncharger
312
-
313
- @property
314
- def lfrag_chooser(self):
315
- """Return the LargestFragmentChooser object."""
316
- if self._lfrag_chooser is None:
317
- self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(
318
- self.params.preferOrganic
319
- )
320
- return self._lfrag_chooser
321
-
322
- @property
323
- def metal_disconnector(self):
324
- """Return the MetalDisconnector object."""
325
- if self._metal_disconnector is None:
326
- self._metal_disconnector = rdMolStandardize.MetalDisconnector()
327
- return self._metal_disconnector
328
-
329
- @property
330
- def normalizer(self):
331
- """Return the Normalizer object."""
332
- if self._normalizer is None:
333
- self._normalizer = rdMolStandardize.Normalizer(
334
- self.params.normalizationsFile, self.params.maxRestarts
335
- )
336
- return self._normalizer
337
-
338
- @property
339
- def reionizer(self):
340
- """Return the Reionizer object."""
341
- if self._reionizer is None:
342
- self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
343
- return self._reionizer
344
-
345
- def charge_parent(self, mol_in):
346
- """Sequentially apply a series of MolStandardize operations:
347
- * MetalDisconnector
348
- * Normalizer
349
- * Reionizer
350
- * LargestFragmentChooser
351
- * Uncharger
352
- The net result is that a desalted, normalized, neutral
353
- molecule with implicit Hs is returned.
354
- """
355
- params = Chem.RemoveHsParameters()
356
- params.removeAndTrackIsotopes = True
357
- mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
358
- if self._metal_disconnect:
359
- mol_in = self.metal_disconnector.Disconnect(mol_in)
360
- normalized = self.normalizer.normalize(mol_in)
361
- Chem.SanitizeMol(normalized)
362
- normalized = self.reionizer.reionize(normalized)
363
- Chem.AssignStereochemistry(normalized)
364
- normalized = self.lfrag_chooser.choose(normalized)
365
- normalized = self.uncharger.uncharge(normalized)
366
- # need this to reassess aromaticity on things like
367
- # cyclopentadienyl, tropylium, azolium, etc.
368
- Chem.SanitizeMol(normalized)
369
- return Chem.RemoveHs(Chem.AddHs(normalized))
370
-
371
- def standardize_mol(self, mol_in):
372
- """
373
- Standardize a single molecule.
374
- :param mol_in: a Chem.Mol
375
- :return: * (standardized Chem.Mol, n_taut) tuple
376
- if success. n_taut will be negative if
377
- tautomer enumeration was aborted due
378
- to reaching a limit
379
- * (None, error_msg) if failure
380
- This calls self.charge_parent() and, if self._canon_taut
381
- is True, runs tautomer canonicalization.
382
- """
383
- n_tautomers = 0
384
- if isinstance(mol_in, Chem.Mol):
385
- name = None
386
- try:
387
- name = mol_in.GetProp("_Name")
388
- except KeyError:
389
- pass
390
- if not name:
391
- name = "NONAME"
392
- else:
393
- error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
394
- return None, error
395
- try:
396
- mol_out = self.charge_parent(mol_in)
397
- except Exception as e:
398
- error = f"charge_parent FAILED: {str(e).strip()}"
399
- return None, error
400
- if self._canon_taut:
401
- try:
402
- res = self.taut_enumerator.Enumerate(mol_out, False)
403
- except TypeError:
404
- # we are still on the pre-2021 RDKit API
405
- res = self.taut_enumerator.Enumerate(mol_out)
406
- except Exception as e:
407
- # something else went wrong
408
- error = f"canon_taut FAILED: {str(e).strip()}"
409
- return None, error
410
- n_tautomers = len(res)
411
- if hasattr(res, "status"):
412
- completed = (
413
- res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
414
- )
415
- else:
416
- # we are still on the pre-2021 RDKit API
417
- completed = len(res) < 1000
418
- if not completed:
419
- n_tautomers = -n_tautomers
420
- try:
421
- mol_out = self.taut_enumerator.PickCanonical(res)
422
- except AttributeError:
423
- # we are still on the pre-2021 RDKit API
424
- mol_out = max(
425
- [(self.taut_enumerator.ScoreTautomer(m), m) for m in res]
426
- )[1]
427
- except Exception as e:
428
- # something else went wrong
429
- error = f"canon_taut FAILED: {str(e).strip()}"
430
- return None, error
431
- mol_out.SetProp("_Name", name)
432
- return mol_out, n_tautomers
433
-
434
-
435
- def load_pickle(path: str):
436
- with open(path, "rb") as file:
437
- content = pickle.load(file)
438
- return content
439
-
440
-
441
- def write_pickle(path: str, obj: object):
442
- with open(path, "wb") as file:
443
- pickle.dump(obj, file)
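A small sketch of how the `Standardizer` above is used on a single molecule (this mirrors `create_cleaned_mol_objects` in `src/preprocess.py`; the SMILES is an arbitrary example):

```python
from rdkit import Chem
from src.utils import Standardizer

sm = Standardizer(canon_taut=True)

mol = Chem.MolFromSmiles("CC(=O)[O-].[Na+]")   # sodium acetate, arbitrary example
std_mol, info = sm.standardize_mol(mol)        # (mol, n_tautomers) on success, (None, error message) on failure

if std_mol is not None:
    print(Chem.MolToSmiles(std_mol))           # desalted, neutralized, canonical tautomer
else:
    print("standardization failed:", info)
```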