| | import pandas as pd |
| | from tqdm import tqdm |
| | from rdkit import Chem, RDLogger |
| | from datasets import load_dataset |
| | from multiprocessing import Pool, cpu_count |
| | import os |
| |
|
| | |
# Turn off RDKit log output for every logger matching 'rdApp.*' so that
# per-molecule messages do not flood the console during bulk processing.
RDLogger.DisableLog('rdApp.*')
| |
|
class SmilesEnumerator:
    """
    Small wrapper around the SMILES randomization logic.

    The class holds no state; it only groups the randomization routine so
    that worker functions can instantiate it and call the method.
    """

    def randomize_smiles(self, smiles):
        """
        Return a randomized, non-canonical SMILES string for ``smiles``.

        The input is returned unchanged when no molecule can be built from it
        or when the conversion raises, so callers always get a value back.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            # A falsy result means no molecule was produced; fall back to the input.
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Deliberate best-effort fallback. Catch ``Exception`` rather than
            # using a bare ``except`` so KeyboardInterrupt / SystemExit still
            # propagate and a worker process can be interrupted.
            return smiles
| |
|
def create_augmented_pair(smiles_string):
    """
    Pool worker: build two randomized SMILES variants of one molecule.

    Each variant comes from an independent ``randomize_smiles`` call on the
    same input string. Returns the 2-tuple ``(first_variant, second_variant)``.
    """
    randomize = SmilesEnumerator().randomize_smiles
    first_variant = randomize(smiles_string)
    second_variant = randomize(smiles_string)
    return first_variant, second_variant
| |
|
def main():
    """
    Run the parallel SMILES-augmentation preprocessing end to end.

    Loads the 'train' split of the configured dataset, keeps at most the
    first 50,000,000 rows, produces two randomized SMILES strings per entry
    using a pool of worker processes, and writes the resulting two-column
    table ('smiles_1', 'smiles_2') to ``output_path`` via
    ``DataFrame.to_parquet``.
    """
    dataset_name = 'jablonkagroup/pubchem-smiles-molecular-formula'
    smiles_column_name = 'smiles'
    output_path = 'data/pubchem_2_epoch_50M'
    max_rows = 50_000_000

    print(f"Loading dataset '{dataset_name}'...")
    dataset = load_dataset(dataset_name)['train']
    # Clamp the selection so a split shorter than max_rows does not raise.
    dataset = dataset.select(range(min(max_rows, len(dataset))))
    smiles_list = dataset[smiles_column_name]
    total = len(smiles_list)
    print(f"Successfully fetched {total} SMILES strings.")

    num_workers = cpu_count()
    print(f"Starting SMILES augmentation with {num_workers} worker processes...")

    # imap defaults to chunksize=1, i.e. one inter-process round trip per
    # SMILES string. Batch the work instead; the cap keeps the progress bar
    # responsive and limits how much each worker buffers at once.
    chunksize = max(1, min(10_000, total // (num_workers * 4)))
    with Pool(num_workers) as p:
        # imap (unlike imap_unordered) yields results in input order.
        results = list(
            tqdm(
                p.imap(create_augmented_pair, smiles_list, chunksize=chunksize),
                total=total,
                desc="Augmenting Pairs",
            )
        )

    print("Processing complete. Converting to DataFrame...")
    df = pd.DataFrame(results, columns=['smiles_1', 'smiles_2'])

    # os.makedirs('') raises, so only create a directory when the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving augmented pairs to '{output_path}'...")
    df.to_parquet(output_path)

    print("All done. Your pre-computed dataset is ready!")
| |
|
# Run the preprocessing only when this file is executed as a script.
# NOTE(review): with multiprocessing start methods that re-import the main
# module in child processes, this guard presumably also keeps workers from
# re-running main() -- confirm against the target platform.
if __name__ == '__main__':
    main()
| |
|
| |
|