Spaces:
Sleeping
Sleeping
| """ | |
| This file includes a predict function for Tox21. | |
| As an input it takes a list of SMILES and it outputs a nested dictionary with | |
| SMILES and target names as keys. | |
| """ | |
| # --------------------------------------------------------------------------------------- | |
| # Dependencies | |
| from collections import defaultdict | |
| import numpy as np | |
| from src.model import Tox21XGBClassifier | |
| from src.preprocess import create_descriptors | |
| # --------------------------------------------------------------------------------------- | |
def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    """Apply the classifier to a list of SMILES strings.

    Molecules that could not be cleaned during preprocessing receive a
    prediction of 0.0 for every target. If no molecule at all survives
    cleaning, the model is not queried and every prediction is 0.0.

    Args:
        smiles_list (list[str]): list of SMILES strings

    Returns:
        dict: nested prediction dictionary, following
            {'<smiles>': {'<target>': <pred>}}. Empty if ``smiles_list`` is
            empty. Duplicate SMILES share a single entry.
    """
    print(f"Received {len(smiles_list)} SMILES strings")
    if not smiles_list:
        # Nothing to featurize or predict; skip preprocessing and model loading.
        return {}

    # preprocessing pipeline
    features, mol_mask = create_descriptors(
        smiles_list,
    )
    print(f"Created {features.shape[1]} descriptors for the molecules.")
    print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning. All predictions for these will be set to 0.0.")

    # setup model
    model = Tox21XGBClassifier(seed=42)
    model_dir = "assets/"
    model.load_model(model_dir)
    print(f"Loaded model and feature processors from {model_dir}")

    # make predictions
    predictions = defaultdict(dict)
    # When every molecule was dropped there are no feature rows; feeding a
    # zero-row matrix to the processors/model may raise, so skip them entirely.
    any_clean = bool(np.any(mol_mask))
    # Running count of clean molecules minus one maps each position in
    # smiles_list to its row in `features`; the value is only read for entries
    # whose mask is truthy.
    feat_indices = np.cumsum(mol_mask) - 1

    for target in model.tasks:
        target_pred = None
        if any_clean:
            feature_processors = model.feature_processors[target]
            task_features = feature_processors['selector'].transform(features)
            task_features = feature_processors['scaler'].transform(task_features)
            target_pred = model.predict(target, task_features)

        for smiles, is_clean, i in zip(smiles_list, mol_mask, feat_indices):
            predictions[smiles][target] = float(target_pred[i]) if is_clean else 0.0

    # Return a plain dict so lookups of unknown SMILES raise KeyError instead of
    # silently inserting empty entries.
    return dict(predictions)
if __name__ == "__main__":
    # Smoke test: run the full pipeline on a few example strings; the last
    # entry is presumably meant to exercise the failed-cleaning path.
    demo_smiles = ["CCO", "CCN", "invalid_smiles"]
    print(predict(demo_smiles))