"""
This file includes a predict function for the Tox21.
As an input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
from collections import defaultdict

import numpy as np

from src.model import Tox21XGBClassifier
from src.preprocess import create_descriptors

# ---------------------------------------------------------------------------------------


def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    """Applies the classifier to a list of SMILES strings.

    Returns prediction=0.0 for any molecule that could not be cleaned.
    If the same SMILES string appears more than once in the input, the
    output contains a single entry for it (keys are the SMILES strings).

    Args:
        smiles_list (list[str]): list of SMILES strings

    Returns:
        dict: nested prediction dictionary, following
            {'<smiles>': {'<target>': <prediction>}}. An empty input list
            yields an empty dictionary.
    """
    print(f"Received {len(smiles_list)} SMILES strings")

    # Nothing to predict: avoid running preprocessing and model loading on
    # an empty batch.
    if not smiles_list:
        return {}

    # preprocessing pipeline
    # `mol_mask` flags, per input SMILES, whether the molecule survived
    # cleaning; `features` holds one row per surviving molecule only (this is
    # what the cumsum-based row lookup below relies on).
    features, mol_mask = create_descriptors(
        smiles_list,
    )
    n_clean = int(np.sum(mol_mask))

    print(f"Created {features.shape[1]} descriptors for the molecules.")
    print(
        f"{len(mol_mask) - n_clean} molecules removed during cleaning. "
        "All predictions for these will be set to 0.0."
    )

    # setup model
    model = Tox21XGBClassifier(seed=42)
    model_dir = "assets/"
    model.load_model(model_dir)
    print(f"Loaded model and feature processors from {model_dir}")

    # make predictions
    predictions: defaultdict[str, dict[str, float]] = defaultdict(dict)

    # Map each input position to its row in `features`: the running count of
    # clean molecules minus one. For removed molecules the value is
    # meaningless (possibly -1) and is never used as an index.
    feat_indices = np.cumsum(mol_mask) - 1

    for target in model.tasks:
        target_pred = None
        if n_clean:
            # Only run the per-task pipeline when at least one molecule is
            # clean; otherwise the feature matrix is empty and every
            # prediction is 0.0 anyway.
            # NOTE(review): selector/scaler are assumed to be fitted
            # transformers that may reject 0-row input -- confirm.
            feature_processors = model.feature_processors[target]
            task_features = feature_processors['selector'].transform(features)
            task_features = feature_processors['scaler'].transform(task_features)
            target_pred = model.predict(target, task_features)

        for smiles, is_clean, i in zip(smiles_list, mol_mask, feat_indices):
            predictions[smiles][target] = (
                float(target_pred[i]) if is_clean else 0.0
            )

    # Return a plain dict so that looking up an unknown SMILES raises
    # KeyError instead of silently inserting an empty entry.
    return dict(predictions)


if __name__ == "__main__":
    # simple test
    test_smiles = [
        "CCO",
        "CCN",
        "invalid_smiles",
    ]
    preds = predict(test_smiles)
    print(preds)