| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import Dataset, DataLoader |
| | import numpy as np |
| | import json |
| | import os |
| | from typing import Dict, List, Tuple, Optional |
| | import random |
| | import re |
| |
|
# One-letter codes of the 20 canonical amino acids; sequences containing any
# other character are rejected by parse_fasta_with_amp_labels.
_CANONICAL_AA = frozenset('ACDEFGHIKLMNPQRSTVWY')


def _iter_fasta_records(handle):
    """Yield ``(header, sequence)`` pairs from an open FASTA text stream.

    ``header`` is the header line without its leading '>'. Sequence lines are
    stripped and concatenated, so wrapped (multi-line) records are supported.
    Sequence lines appearing before the first header are ignored.
    """
    header = None
    chunks = []
    for raw_line in handle:
        line = raw_line.strip()
        if line.startswith('>'):
            if header is not None:
                yield header, ''.join(chunks)
            header = line[1:]
            chunks = []
        elif header is not None:
            chunks.append(line)
    # Flush the final record (no trailing header terminates it).
    if header is not None:
        yield header, ''.join(chunks)


def _label_from_header(header: str) -> int:
    """Return 0 (AMP) for 'AP...' headers and 1 (Non-AMP) for 'sp...' headers.

    Any other prefix falls back to Non-AMP (1) and prints a warning.
    """
    if header.startswith('AP'):
        return 0
    if header.startswith('sp'):
        return 1
    print(f"Warning: Unknown header prefix in '{header}', defaulting to Non-AMP")
    return 1


def parse_fasta_with_amp_labels(fasta_path: str,
                                max_seq_len: int = 50,
                                mask_probability: float = 0.1,
                                seed: Optional[int] = None) -> Dict[str, List]:
    """
    Parse a FASTA file and assign AMP/Non-AMP labels from header prefixes.

    Label assignment:
      - AMP (0): headers starting with '>AP'
      - Non-AMP (1): headers starting with '>sp' (also used, with a warning,
        for any other prefix)
      - Mask (2): assigned to a random subset of records for CFG training

    A record is kept only if its header and sequence are non-empty, the
    sequence length is within [2, max_seq_len], and (after upper-casing) it
    contains only the 20 canonical amino-acid letters. Multi-line sequences
    are concatenated.

    Args:
        fasta_path: Path to the FASTA file.
        max_seq_len: Maximum sequence length to include.
        mask_probability: Fraction of kept records whose label is replaced by
            2 in ``masked_labels``; the count is ``int(n * mask_probability)``.
        seed: Optional seed for a private ``np.random.default_rng``. When None
            the global ``np.random`` state is used.

    Returns:
        Dict with lists ``'sequences'`` (upper-cased), ``'headers'`` (without
        '>'), ``'labels'``, ``'masked_labels'`` and ``'mask_indices'``.

    Raises:
        ValueError: if ``mask_probability`` is outside [0, 1].
        OSError: if the file cannot be opened.
    """
    if not 0.0 <= mask_probability <= 1.0:
        raise ValueError(f"mask_probability must be in [0, 1], got {mask_probability}")

    sequences = []
    labels = []
    headers = []

    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")

    with open(fasta_path, 'r') as handle:
        for header, raw_sequence in _iter_fasta_records(handle):
            if not header or not raw_sequence:
                continue
            if not 2 <= len(raw_sequence) <= max_seq_len:
                continue
            sequence = raw_sequence.upper()
            if not set(sequence) <= _CANONICAL_AA:
                continue
            sequences.append(sequence)
            headers.append(header)
            labels.append(_label_from_header(header))

    # dtype pinned so an empty file still produces an integer label array.
    original_labels = np.array(labels, dtype=np.int64)
    masked_labels = original_labels.copy()
    # Without a seed keep drawing from NumPy's global RNG, so callers that
    # seed np.random get the same masks as before.
    sampler = np.random if seed is None else np.random.default_rng(seed)
    mask_indices = sampler.choice(
        len(original_labels),
        size=int(len(original_labels) * mask_probability),
        replace=False,
    )
    masked_labels[mask_indices] = 2

    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f" AMP sequences: {np.sum(original_labels == 0)}")
    print(f" Non-AMP sequences: {np.sum(original_labels == 1)}")
    print(f" Masked for CFG: {len(mask_indices)}")

    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist(),
    }
| |
|
class CFGUniProtDataset(Dataset):
    """
    Dataset of peptide sequences with AMP/Non-AMP labels for classifier-free
    guidance (CFG) training.

    The JSON file at ``data_path`` must contain:
      - 'sequences': list of sequence strings
      - 'original_labels' (or 'labels', the key emitted by
        parse_fasta_with_amp_labels): unmasked labels, 0 = AMP, 1 = Non-AMP
      - 'masked_labels': labels where CFG-masked entries are 2
      - 'mask_indices': indices of the masked entries

    Each item is a dict with the raw sequence plus label tensors.
    """

    def __init__(self,
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        # Informational only: which entries are masked is read from the file.
        self.mask_probability = mask_probability
        # Stored for API compatibility; not used within this class.
        self.max_seq_len = max_seq_len
        self.device = device

        self._load_data()

        self.label_map = {
            0: 'amp',
            1: 'non_amp',
            2: 'mask'
        }

        print(f"CFG Dataset initialized:")
        print(f" Total sequences: {len(self.sequences)}")
        print(f" Using masked labels: {use_masked_labels}")
        print(f" Mask probability: {mask_probability}")
        print(f" Label distribution: {self._get_label_distribution()}")

    def _load_data(self):
        """Load sequences, labels and mask indices from the JSON file.

        Raises:
            FileNotFoundError: if ``data_path`` does not exist.
            KeyError: if a required key is missing.
        """
        try:
            with open(self.data_path, 'r') as f:
                data = json.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Data file not found: {self.data_path}") from err

        # Accept the 'labels' key written by parse_fasta_with_amp_labels too.
        if 'original_labels' in data:
            original = data['original_labels']
        elif 'labels' in data:
            original = data['labels']
        else:
            raise KeyError('original_labels')

        self.sequences = data['sequences']
        self.original_labels = np.array(original)
        self.masked_labels = np.array(data['masked_labels'])
        self.mask_indices = set(data['mask_indices'])

    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Count labels by name.

        Args:
            labels: Label array to count. Defaults to the labels ``__getitem__``
                serves (masked if ``use_masked_labels`` else original).
        """
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the sequence, its active label, original label and mask flag."""
        sequence = self.sequences[idx]

        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        # Reflects the file's mask membership regardless of use_masked_labels.
        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Return label distributions and masking info.

        'original' always counts the unmasked labels; 'masked' counts the
        masked labels, or is None when ``use_masked_labels`` is False.
        """
        return {
            'original': self._get_label_distribution(self.original_labels),
            'masked': (self._get_label_distribution(self.masked_labels)
                       if self.use_masked_labels else None),
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': sorted(self.mask_indices)
            }
        }
| |
|
class CFGFlowDataset(Dataset):
    """
    Dataset pairing precomputed peptide embeddings with CFG labels.

    Embeddings come from ``embeddings_path``. If it contains
    ``all_peptide_embeddings.pt``, that file is loaded. Otherwise every other
    ``*.pt`` file holding a 2-D tensor is stacked, in sorted filename order.

    Labels come from a FASTA file (labelled via parse_fasta_with_amp_labels)
    or from a JSON file with 'sequences' and 'labels' keys.

    Embeddings and labels are paired purely by position (see ``_align_data``).
    """

    def __init__(self,
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        # Stored for API compatibility; not used within this class.
        self.device = device

        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print(f"CFG Flow Dataset initialized:")
        print(f" AMP embeddings: {self.embeddings.shape}")
        print(f" CFG labels: {len(self.cfg_labels)}")
        print(f" Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self):
        """Load embeddings into ``self.embeddings`` on the CPU.

        Raises:
            ValueError: if the per-file fallback finds no loadable 2-D tensor.
        """
        import glob

        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")

        # glob returns files in arbitrary order. Sort them so the stacking
        # order is deterministic; _align_data pairs embeddings with labels by
        # position.
        embedding_files = sorted(
            path for path in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not path.endswith('all_peptide_embeddings.pt')
        )

        print(f"Found {len(embedding_files)} individual embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            # Best effort: unreadable or non-tensor files are reported and skipped.
            try:
                embedding = torch.load(file_path, map_location='cpu')
                if embedding.dim() == 2:
                    embeddings_list.append(embedding)
                else:
                    print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        # torch.stack requires every per-file tensor to have the same shape.
        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self):
        """Load sequences and labels from a FASTA (.fasta/.fa) or JSON file.

        For JSON input, 10% of the labels (``int(n * 0.1)``) are masked to 2
        using NumPy's global RNG.
        """
        print(f"Loading CFG data from FASTA: {self.cfg_data_path}...")

        if self.cfg_data_path.endswith(('.fasta', '.fa')):
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'], dtype=np.int64)
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'], dtype=np.int64)
            self.cfg_mask_indices = set(cfg_data['mask_indices'])
        else:
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            # dtype pinned so an empty label list stays integer (np.bincount
            # below rejects float arrays).
            self.cfg_original_labels = np.array(cfg_data['labels'], dtype=np.int64)

            self.cfg_masked_labels = self.cfg_original_labels.copy()
            mask_probability = 0.1
            mask_indices = np.random.choice(
                len(self.cfg_original_labels),
                size=int(len(self.cfg_original_labels) * mask_probability),
                replace=False
            )
            self.cfg_masked_labels[mask_indices] = 2
            self.cfg_mask_indices = set(mask_indices.tolist())

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self):
        """Pair embeddings with CFG labels by position.

        No sequence matching is performed. Sample ``i`` uses embedding ``i``
        and CFG label ``i``, and both are truncated to the shorter length.
        This is only meaningful if both sources list peptides in the same
        order. NOTE(review): confirm upstream.
        """
        print("Aligning AMP embeddings with CFG data...")

        min_samples = min(len(self.embeddings), len(self.cfg_sequences))
        self.aligned_indices = list(range(min_samples))

        if self.use_masked_labels:
            self.cfg_labels = self.cfg_masked_labels[:min_samples]
        else:
            self.cfg_labels = self.cfg_original_labels[:min_samples]

        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the embedding, active/original label and mask flag for ``idx``."""
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_embedding_stats(self) -> Dict:
        """Return shape and global mean/std/min/max of the aligned embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }
| |
|
def _cfg_collate_fn(batch: List[Dict]) -> Dict:
    """Collate CFG sample dicts into a batch dict.

    Defined at module level (not nested) so it can be pickled when DataLoader
    workers are started with the 'spawn' method.

    Items must provide 'label', 'original_label', 'is_masked' and 'index'
    tensors, plus an 'embedding' tensor and/or a 'sequence' string. Tensors
    are stacked along a new batch dimension; sequences are kept as a list.
    """
    first = batch[0]
    if 'embedding' not in first and 'sequence' not in first:
        raise KeyError('embedding')

    collated = {}
    if 'embedding' in first:
        collated['embeddings'] = torch.stack([item['embedding'] for item in batch])
    if 'sequence' in first:
        collated['sequences'] = [item['sequence'] for item in batch]
    collated['labels'] = torch.stack([item['label'] for item in batch])
    collated['original_labels'] = torch.stack([item['original_label'] for item in batch])
    collated['is_masked'] = torch.stack([item['is_masked'] for item in batch])
    collated['indices'] = torch.stack([item['index'] for item in batch])
    return collated


def create_cfg_dataloader(dataset: Dataset,
                          batch_size: int = 32,
                          shuffle: bool = True,
                          num_workers: int = 4,
                          pin_memory: Optional[bool] = None) -> DataLoader:
    """Create a DataLoader for CFG training.

    Args:
        dataset: Map-style dataset whose items are dicts as described in
            ``_cfg_collate_fn`` (e.g. CFGFlowDataset or CFGUniProtDataset items).
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle every epoch.
        num_workers: Number of worker processes (0 loads in the main process).
        pin_memory: Pin host memory for faster GPU transfer. Defaults to
            ``torch.cuda.is_available()``, so CPU-only runs do not request it.

    Returns:
        A DataLoader yielding dicts with 'embeddings' and/or 'sequences',
        plus 'labels', 'original_labels', 'is_masked' and 'indices'.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=_cfg_collate_fn,
        pin_memory=pin_memory
    )
| |
|
def test_cfg_dataset():
    """Smoke-test CFGUniProtDataset on a small in-memory JSON fixture.

    The fixture is written into a fresh temporary directory, not the current
    working directory. It is removed even if loading the dataset raises.
    """
    import tempfile

    print("Testing CFG Dataset...")

    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],
        'masked_labels': [0, 2, 0],
        'mask_indices': [1]
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        test_path = os.path.join(tmp_dir, 'test_cfg_data.json')
        with open(test_path, 'w') as f:
            json.dump(test_data, f)

        dataset = CFGUniProtDataset(test_path, use_masked_labels=True)

        print(f"Dataset length: {len(dataset)}")
        for i in range(len(dataset)):
            sample = dataset[i]
            print(f"Sample {i}:")
            print(f" Sequence: {sample['sequence'][:20]}...")
            print(f" Label: {sample['label'].item()}")
            print(f" Original Label: {sample['original_label'].item()}")
            print(f" Is Masked: {sample['is_masked'].item()}")

    print("Test completed successfully!")
| |
|
if __name__ == "__main__":
    # Script entry point: run the dataset smoke test when executed directly.
    test_cfg_dataset()