| | |
| | |
| | |
| | import os |
| | import warnings |
| | import wandb |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, Dataset |
| | import numpy as np |
| | from tqdm import tqdm |
| | from rdkit import Chem, RDLogger |
| | from datasets import load_dataset, load_from_disk |
| | from transformers import AutoTokenizer, BertModel, BertConfig |
| | import pandas as pd |
| |
|
| | |
| | |
| | |
| | |
# Keep the console readable during long runs: drop every Python warning and
# switch off all RDKit log channels matching 'rdApp.*'.
warnings.simplefilter("ignore")
RDLogger.DisableLog('rdApp.*')
| |
|
| | |
| | |
| | |
def global_average_pooling(x):
    """Average over the sequence axis.

    Maps a tensor laid out as [B, max_len, hid_dim] to [B, hid_dim] by taking
    the unweighted mean along dimension 1.
    """
    return x.mean(dim=1)
| |
|
class SimSonEncoder(nn.Module):
    """BERT backbone followed by dropout, mean pooling and a linear head.

    The head maps the pooled hidden vector (``config.hidden_size``) to
    ``max_len`` output features.
    """

    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super().__init__()
        # The BERT pooler is disabled; pooling is done in forward().
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids into one projection vector per row.

        When no attention mask is supplied, one is derived by marking every
        token that differs from the BERT config's ``pad_token_id``.
        """
        if attention_mask is None:
            pad_id = self.bert.config.pad_token_id
            attention_mask = input_ids.ne(pad_id)

        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = self.dropout(encoded.last_hidden_state)
        sequence_vector = global_average_pooling(token_states)
        return self.linear(sequence_vector)
| |
|
class ContrastiveLoss(nn.Module):
    """NT-Xent style contrastive loss over two batches of projections.

    Row ``i`` of ``proj_1`` and row ``i`` of ``proj_2`` form a positive pair;
    every other row in the concatenated batch (excluding the row itself) acts
    as a negative.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            softmax. Smaller values sharpen the distribution.
    """

    def __init__(self, temperature=0.2):
        super().__init__()
        self.temperature = temperature
        self.similarity_fn = F.cosine_similarity

    def forward(self, proj_1, proj_2):
        """Return the mean loss over the ``2 * batch_size`` anchors.

        Args:
            proj_1: Projections of the first view, one row per sample.
            proj_2: Projections of the second view, same shape as ``proj_1``.

        Returns:
            A scalar tensor.
        """
        batch_size = proj_1.shape[0]
        total = 2 * batch_size

        z_i = F.normalize(proj_1, p=2, dim=1)
        z_j = F.normalize(proj_2, p=2, dim=1)
        representations = torch.cat([z_i, z_j], dim=0)

        # Pairwise similarity between every pair of rows: (2B, 2B).
        similarity_matrix = self.similarity_fn(
            representations.unsqueeze(1), representations.unsqueeze(0), dim=2
        )
        logits = similarity_matrix / self.temperature

        # An anchor must not count itself as a candidate. Filling with -inf
        # makes its softmax weight exactly zero.
        self_mask = torch.eye(total, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float('-inf'))

        # The positive for row i sits B rows away (wrapping around), i.e. the
        # +B and -B off-diagonals of the similarity matrix.
        targets = (torch.arange(total, device=logits.device) + batch_size) % total

        # -log(exp(pos) / sum(exp(all others))) is a cross-entropy over the
        # masked logits. Going through log-sum-exp avoids the inf/inf = NaN
        # that an explicit exp() produces once 1/temperature exceeds the
        # float32 exponent range.
        loss = F.cross_entropy(logits, targets, reduction='sum')
        return loss / total
| |
|
| | |
| | |
| | |
class SmilesEnumerator:
    """Generates randomized SMILES strings for data augmentation."""

    def randomize_smiles(self, smiles):
        """Return a randomized, non-canonical SMILES for the same molecule.

        This is a best-effort augmentation: if RDKit cannot parse the input
        or raises while writing it back out, the input is returned unchanged.

        Args:
            smiles: SMILES string to re-write.

        Returns:
            A randomized SMILES string, or ``smiles`` itself on failure.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                return smiles
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
        except Exception:
            # Deliberately broad, but no longer a bare ``except:`` so that
            # KeyboardInterrupt / SystemExit still propagate out of
            # DataLoader workers.
            return smiles
| |
|
class ContrastiveSmilesDataset(Dataset):
    """Yields two independently randomized, tokenized views of each SMILES."""

    def __init__(self, smiles_list, tokenizer, max_length=512):
        self.smiles_list = smiles_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.enumerator = SmilesEnumerator()

    def __len__(self):
        """Number of source SMILES strings."""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Build the tensor dict for the sample at ``idx``.

        Both views are randomized first and tokenized afterwards; each is
        truncated and padded to ``max_length``.
        """
        source = self.smiles_list[idx]
        views = [self.enumerator.randomize_smiles(source) for _ in range(2)]
        first, second = (
            self.tokenizer(view, max_length=self.max_length, truncation=True, padding='max_length')
            for view in views
        )
        return {
            'input_ids_1': torch.tensor(first['input_ids']),
            'attention_mask_1': torch.tensor(first['attention_mask']),
            'input_ids_2': torch.tensor(second['input_ids']),
            'attention_mask_2': torch.tensor(second['attention_mask']),
        }
| |
|
class PrecomputedContrastiveSmilesDataset(Dataset):
    """
    Serves SMILES pairs that were augmented ahead of time and stored in a
    Parquet file with ``smiles_1`` and ``smiles_2`` columns. Only tokenization
    happens at access time; the randomization cost was paid once during
    preprocessing.
    """

    def __init__(self, tokenizer, file_path: str, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # The whole table is read into a DataFrame up front.
        print(f"Loading pre-computed data from {file_path}...")
        self.data = pd.read_parquet(file_path)
        print("Data loaded successfully.")

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one SMILES string, truncated and padded to ``max_length``."""
        return self.tokenizer(text, max_length=self.max_length, truncation=True, padding='max_length')

    def __getitem__(self, idx):
        """Look up the pair at row ``idx``, tokenize both sides, return tensors."""
        record = self.data.iloc[idx]
        first = self._encode(record['smiles_1'])
        second = self._encode(record['smiles_2'])

        return {
            'input_ids_1': torch.tensor(first['input_ids']),
            'attention_mask_1': torch.tensor(first['attention_mask']),
            'input_ids_2': torch.tensor(second['input_ids']),
            'attention_mask_2': torch.tensor(second['attention_mask']),
        }
| |
|
class PreTokenizedSmilesDataset(Dataset):
    """
    Thin wrapper around a dataset saved to disk by the preprocessing script.
    The stored rows are expected to already hold token ids and attention
    masks for both views; they are exposed as torch tensors.
    """

    def __init__(self, dataset_path: str):
        tensor_columns = [
            'input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2'
        ]
        stored = load_from_disk(dataset_path)
        # Only these four columns are returned, converted to torch tensors.
        stored.set_format(type='torch', columns=tensor_columns)
        self.dataset = stored
        print(f"Successfully loaded pre-tokenized dataset from {dataset_path}.")

    def __len__(self):
        """Number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return row ``idx`` of the wrapped dataset as-is."""
        return self.dataset[idx]
| |
|
| |
|
class DataCollatorWithPadding:
    """
    Collate function that pads both augmented views of every sample to the
    longest sequence found in either view, so all returned tensors share the
    same width.

    Args:
        tokenizer: Object exposing ``pad(features, padding=..., return_tensors=...)``
            that returns a mapping with ``input_ids`` and ``attention_mask``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        """Pad and batch a list of per-sample dicts.

        Args:
            features: Items carrying ``input_ids_1``, ``attention_mask_1``,
                ``input_ids_2`` and ``attention_mask_2``.

        Returns:
            A dict with the same four keys. Row ``i`` of the ``*_1`` tensors
            and row ``i`` of the ``*_2`` tensors come from ``features[i]``.
        """
        batch_size = len(features)

        # Stack all first views, then all second views. The previous version
        # interleaved them (a1, a2, b1, b2, ...) and then cut the result in
        # half, which put both views of the same sample into the same half and
        # broke the row-wise positive pairing the loss relies on.
        first_views = [
            {'input_ids': f['input_ids_1'], 'attention_mask': f['attention_mask_1']}
            for f in features
        ]
        second_views = [
            {'input_ids': f['input_ids_2'], 'attention_mask': f['attention_mask_2']}
            for f in features
        ]

        # Padding both views in one call gives them a common sequence length.
        padded = self.tokenizer.pad(first_views + second_views, padding='longest', return_tensors='pt')
        input_ids = padded['input_ids']
        attention_mask = padded['attention_mask']

        return {
            'input_ids_1': input_ids[:batch_size],
            'attention_mask_1': attention_mask[:batch_size],
            'input_ids_2': input_ids[batch_size:],
            'attention_mask_2': attention_mask[batch_size:],
        }
| |
|
| | |
| | |
| | |
def evaluation_step(model, batch, criterion, device):
    """Run one gradient-free forward pass over both views of a batch.

    The two views are concatenated along the batch axis so the model is
    called once, then the output is split back into per-view projections.

    Returns:
        ``(proj_1, proj_2, loss)`` where ``loss = criterion(proj_1, proj_2)``.
    """
    ids_a = batch['input_ids_1'].to(device)
    mask_a = batch['attention_mask_1'].to(device)
    ids_b = batch['input_ids_2'].to(device)
    mask_b = batch['attention_mask_2'].to(device)

    stacked_ids = torch.cat([ids_a, ids_b], dim=0)
    stacked_mask = torch.cat([mask_a, mask_b], dim=0)

    with torch.no_grad():
        projections = model(stacked_ids, stacked_mask)

    # First half of the rows belongs to view 1, second half to view 2.
    proj_1, proj_2 = torch.split(projections, ids_a.size(0), dim=0)

    return proj_1, proj_2, criterion(proj_1, proj_2)
| |
|
def train_epoch(model, train_loader, optimizer, criterion, device, scheduler, save_path, save_steps):
    """Train ``model`` for one pass over ``train_loader``.

    Per batch: forward both views in a single call (under float16 autocast
    when running on CUDA), compute the contrastive loss, back-propagate, clip
    the gradient norm to 1.0, step the optimizer and then the scheduler, and
    log the batch loss and learning rate to wandb.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        train_loader: Iterable of batches with ``input_ids_1``,
            ``attention_mask_1``, ``input_ids_2`` and ``attention_mask_2``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable ``criterion(proj_1, proj_2)`` returning a scalar.
        device: Device the batch tensors are moved to.
        scheduler: LR scheduler stepped once per batch.
        save_path: Where periodic checkpoints are written; falsy disables them.
        save_steps: Checkpoint every this many steps; falsy disables them.

    Returns:
        The mean batch loss over the epoch.
    """
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc="Training Batch", leave=False)

    # Autocast targets CUDA; switch it off elsewhere instead of letting
    # PyTorch warn and disable it on every step.
    use_amp = torch.device(device).type == "cuda"

    for step, batch in enumerate(progress_bar, 1):
        input_ids_1 = batch['input_ids_1'].to(device)
        attention_mask_1 = batch['attention_mask_1'].to(device)
        input_ids_2 = batch['input_ids_2'].to(device)
        attention_mask_2 = batch['attention_mask_2'].to(device)

        optimizer.zero_grad()
        # NOTE(review): float16 autocast runs without a GradScaler here, so
        # small gradients may underflow - consider bfloat16 or adding a scaler.
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
            combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)

            combined_proj = model(combined_input_ids, combined_attention_mask)

            batch_size = input_ids_1.size(0)
            proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)

            loss = criterion(proj_1, proj_2)

        loss.backward()

        # Clip BEFORE the optimizer consumes the gradients; clipping after
        # optimizer.step() has no effect on the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        progress_bar.set_postfix(loss=f"{batch_loss:.4f}")
        wandb.log({
            "train_batch_loss": batch_loss,
            "learning_rate": scheduler.get_last_lr()[0]
        })
        # ``save_steps`` of 0/None means "no periodic checkpoints" rather than
        # a ZeroDivisionError.
        if save_path and save_steps and step % save_steps == 0:
            torch.save(model.state_dict(), save_path)
            progress_bar.write(f"Checkpoint saved at step {step}")

    return total_loss / len(train_loader)
| |
|
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on every batch of ``val_loader``.

    Puts the model in eval mode, prints the mean batch loss and returns it.
    """
    model.eval()
    running_loss = 0

    for batch in tqdm(val_loader, desc="Validating", leave=False):
        step_loss = evaluation_step(model, batch, criterion, device)[2]
        running_loss += step_loss.item()

    mean_loss = running_loss / len(val_loader)
    print(f'Validation loss: {mean_loss}')
    return mean_loss
| |
|
def test_model(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader``.

    Returns:
        ``(avg_loss, avg_sim, std_sim)``: the mean batch loss, and the mean
        and standard deviation of the cosine similarity between the two
        projections of each sample.
    """
    model.eval()
    loss_sum = 0
    pair_similarities = []

    for batch in tqdm(test_loader, desc="Testing", leave=False):
        proj_1, proj_2, loss = evaluation_step(model, batch, criterion, device)
        loss_sum += loss.item()

        unit_1 = F.normalize(proj_1, p=2, dim=1)
        unit_2 = F.normalize(proj_2, p=2, dim=1)
        # One similarity per sample: row i of view 1 against row i of view 2.
        row_similarity = F.cosine_similarity(unit_1, unit_2, dim=1)
        pair_similarities.extend(row_similarity.cpu().numpy())

    mean_loss = loss_sum / len(test_loader)
    return mean_loss, np.mean(pair_similarities), np.std(pair_similarities)
| |
|
| | |
| | |
| | |
def _load_checkpoint_if_present(model, path, device):
    """Load a state dict from ``path`` into ``model`` if the file exists; return whether it did."""
    if not path or not os.path.isfile(path):
        return False
    # map_location keeps a checkpoint written on GPU loadable on a CPU-only host.
    # weights_only restricts unpickling to tensors and plain containers.
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    return True


def run_training(model_config, hparams, data_splits):
    """Train, validate and test the SimSon encoder, logging to wandb.

    Args:
        model_config: ``BertConfig`` used to build the encoder.
        hparams: Dict with ``max_length``, ``batch_size``, ``max_embeddings``,
            ``temperature``, ``lr``, ``epochs``, ``save_path`` and ``save_steps``.
        data_splits: Accepted for interface compatibility. The data actually
            used comes from the pre-computed Parquet splits under
            ``data/splits/``; this argument is not read.

    If a checkpoint already exists at ``hparams['save_path']`` it is loaded
    and validated before training starts. Otherwise training starts from the
    freshly initialized model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    wandb_key = os.getenv("WANDB_API_KEY")
    if wandb_key:
        wandb.login(key=wandb_key)
    wandb.init(
        project="simson-contrastive-learning-single-gpu",
        name=f"run-{wandb.util.generate_id()}",
        config=hparams
    )

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')

    precomputed_train_path = 'data/splits/train.parquet'
    precomputed_test_path = 'data/splits/test.parquet'
    precomputed_val_path = 'data/splits/validation.parquet'

    train_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_train_path, max_length=hparams['max_length'])
    test_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_test_path, max_length=hparams['max_length'])
    val_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_val_path, max_length=hparams['max_length'])

    train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True, num_workers=16, prefetch_factor=128, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
    print('Initialized all data. Compiling the model...')
    model = SimSonEncoder(config=model_config, max_len=hparams['max_embeddings']).to(device)
    model = torch.compile(model)
    print(model)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"Total number of parameters: {total_params // 1_000_000} M")
    wandb.config.update({"total_params_M": total_params // 1_000_000})

    criterion = ContrastiveLoss(temperature=hparams['temperature']).to(device)
    # The fused AdamW kernel is only requested on CUDA so the CPU fallback
    # selected above keeps working.
    optimizer = optim.AdamW(model.parameters(), lr=hparams['lr'], weight_decay=1e-5, fused=(device.type == 'cuda'))
    print(f"Len of dataloader is {len(train_loader)}, with bs: {hparams['batch_size']}")
    # One cosine cycle spanning the whole run (scheduler is stepped per batch).
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_mult=1, T_0=int(hparams['epochs'] * len(train_loader)))
    print("Starting training...")
    wandb.watch(model, log='all', log_freq=5000)

    best_val_loss = float('inf')
    epoch_iterator = tqdm(range(hparams['epochs']), desc="Epochs")

    # Resume from an existing checkpoint when there is one. A first run has
    # no file yet and must not crash here.
    if _load_checkpoint_if_present(model, hparams['save_path'], device):
        validate_epoch(model, val_loader, criterion, device)
    else:
        print(f"No checkpoint found at {hparams['save_path']}; training from scratch.")

    for epoch in epoch_iterator:
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, scheduler, hparams['save_path'], hparams['save_steps'])
        val_loss = validate_epoch(model, val_loader, criterion, device)
        epoch_iterator.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=f"{val_loss:.4f}")
        wandb.log({
            "epoch": epoch + 1,
            "train_epoch_loss": train_loss,
            "val_epoch_loss": val_loss,
        })

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), hparams['save_path'])
            epoch_iterator.write(f"Epoch {epoch + 1}: New best model saved with val loss {val_loss:.4f}")

    epoch_iterator.write("Training complete. Starting final testing...")

    # Test the weights on disk. If nothing was ever written (e.g. zero
    # epochs), fall back to the in-memory model.
    _load_checkpoint_if_present(model, hparams['save_path'], device)

    test_loss, avg_sim, std_sim = test_model(model, test_loader, criterion, device)

    print("\n--- Test Results ---")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Average Cosine Similarity: {avg_sim:.4f} \u00B1 {std_sim:.4f}")
    print("--------------------")

    wandb.log({
        "test_loss": test_loss,
        "avg_cosine_similarity": avg_sim,
        "std_cosine_similarity": std_sim
    })

    wandb.finish()
| |
|
| | |
| | |
| | |
def main():
    """Configure hyper-parameters, build the model config and start training."""
    hparams = {
        'epochs': 1,
        'lr': 1e-5,
        'temperature': 0.05,
        'batch_size': 64,
        'max_length': 128,
        'save_path': "simson_checkpoints/simson_model_single_gpu.bin",
        'save_steps': 100_000,
        'max_embeddings': 512,
    }

    # NOTE(review): run_training() reads pre-computed Parquet splits and does
    # not consume these lists, so this download and split may be unnecessary.
    # Confirm before removing.
    dataset = load_dataset('HoangHa/SMILES-250M')['train']
    smiles_column_name = 'SMILES'

    # 10% test, then 10% of the remainder for validation, rest for training.
    total_size = len(dataset)
    test_size = int(0.1 * total_size)
    val_size = int(0.1 * (total_size - test_size))

    test_smiles = dataset.select(range(test_size))[smiles_column_name]
    val_smiles = dataset.select(range(test_size, test_size + val_size))[smiles_column_name]
    train_smiles = dataset.select(range(test_size + val_size, total_size))[smiles_column_name]
    data_splits = (train_smiles, val_smiles, test_smiles)

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    model_config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512,
        # Keep the model's notion of "padding" aligned with the tokenizer; the
        # encoder derives its fallback attention mask from this id.
        pad_token_id=tokenizer.pad_token_id,
    )

    # dirname() is '' for a bare filename, and makedirs('') raises.
    # exist_ok avoids the check-then-create race.
    save_dir = os.path.dirname(hparams['save_path'])
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    run_training(model_config, hparams, data_splits)
| |
|
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()
| |
|