import os
import sys

import torch
import torch.nn as nn
from omegaconf import OmegaConf
from safetensors.torch import load_file

# Add audio-embeddings to the import path dynamically. We assume
# audio-embeddings is a sibling directory to xares-llm or is provided via an
# env var; absolute paths take priority over relative ones in the list below.
POSSIBLE_PATHS = [
    # "/media/ltuncay/Shared-4TB/dev/audio-embeddings",
    os.path.abspath(os.path.join(os.path.dirname(__file__), "audio-embeddings")),
    # os.path.abspath(os.path.join(os.getcwd(), "../audio-embeddings")),
]

AUDIO_EMBEDDINGS_PATH = None
for p in POSSIBLE_PATHS:
    if os.path.exists(p):
        AUDIO_EMBEDDINGS_PATH = p
        break

if AUDIO_EMBEDDINGS_PATH:
    if AUDIO_EMBEDDINGS_PATH not in sys.path:
        sys.path.append(AUDIO_EMBEDDINGS_PATH)
        print(f"Added {AUDIO_EMBEDDINGS_PATH} to sys.path")
else:
    print(
        "Warning: audio-embeddings path not found. Imports may fail if it is "
        "not installed in the environment."
    )

try:
    from src.models.best_rq2_module import BestRQ2Module
except ImportError as e:
    raise ImportError(
        "Could not import src.models.best_rq2_module. Ensure audio-embeddings "
        f"is correctly located or installed. Error: {e}"
    ) from e


class BestRQ2Encoder(nn.Module):
    def __init__(self, checkpoint_path=None, model_config_path=None, **kwargs):
        super().__init__()
        base_path = os.path.dirname(__file__)
        # Fall back to the files shipped next to this module when explicit
        # paths are not provided.
        if model_config_path is None:
            model_config_path = os.path.join(base_path, "config.yaml")
        if checkpoint_path is None:
            checkpoint_path = os.path.join(base_path, "BEST-RQ-2.safetensors")

        if not os.path.exists(model_config_path):
            raise FileNotFoundError(f"Config not found at {model_config_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        print(f"Loading BestRQ2 config from {model_config_path}")
        cfg = OmegaConf.load(model_config_path)

        print(f"Loading BestRQ2 checkpoint from {checkpoint_path}")

        # Reconstruct model args from config
        model_cfg = cfg.model
        net_cfg = model_cfg.net

        # Instantiate the model. Note: BestRQ2Module inherits from
        # LightningModule, so training-only arguments can be stubbed out.
        self.module = BestRQ2Module(
            optimizer=None,  # not needed for inference
            net=net_cfg,
            warmup_pct=model_cfg.get("warmup_pct", 0.1),
            final_lr_ratio=model_cfg.get("final_lr_ratio", 0.001),
            spectrogram_adjustment_mode=model_cfg.get(
                "spectrogram_adjustment_mode", "pad"
            ),
            codebook_dim=model_cfg.get("codebook_dim", 16),
            vocab_size=model_cfg.get("vocab_size", 8192),
            criterion=None,
        )

        # Load weights from safetensors; fall back to torch.load for plain
        # PyTorch / Lightning checkpoints.
        try:
            state_dict = load_file(checkpoint_path)
        except Exception as e:
            print(f"Error loading safetensors: {e}. Trying torch.load...")
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]

        # LightningModules usually save state_dict keys that match the model
        # attributes directly, but checkpoints are sometimes wrapped with a
        # prefix such as "module.". Load with strict=False and report any
        # mismatches rather than failing hard.
        missing, unexpected = self.module.load_state_dict(state_dict, strict=False)
        if missing:
            print(f"Warning: {len(missing)} keys missing during loading.")
        if unexpected:
            print(f"Warning: {len(unexpected)} keys unexpected during loading.")

        self.module.eval()
        self.output_dim = net_cfg.encoder.embed_dim
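        # A hedged debugging sketch (illustrative only, not executed): when
        # load_state_dict reports many missing keys, comparing checkpoint keys
        # against the module's own keys usually exposes a shared prefix such
        # as "module." that should be stripped before loading:
        #
        #     from safetensors import safe_open
        #     with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        #         print(sorted(f.keys())[:5])
        #     print(sorted(self.module.state_dict().keys())[:5])
        #     # e.g.: {k.removeprefix("module."): v for k, v in state_dict.items()}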
        # Extract dynamic parameters for length handling.
        try:
            # 1. Sample rate & hop length (from the spectrogram):
            #    BestRQ2Module -> Spectrogram -> MelSpectrogram -> hop_length
            self.sample_rate = self.module.spectrogram.mel_spec.sample_rate
            self.hop_length = self.module.spectrogram.mel_spec.hop_length

            # 2. Patch size (time dimension):
            #    BestRQ2Module -> PatchEmbed -> patch_size (H, W); W is time
            self.patch_size_time = self.module.patch_embed.patch_size[1]

            # 3. Max input frames (time dimension):
            #    BestRQ2Module -> PatchEmbed -> img_size (H, W); W is time frames
            self.max_frames = self.module.patch_embed.img_size[1]

            # Minimum samples required to get at least one patch width in the
            # spectrogram: we need T_spec >= patch_size_time, and T_spec is
            # roughly T_samples // hop_length, so
            # T_samples >= patch_size_time * hop_length.
            self.min_samples = self.patch_size_time * self.hop_length

            # Chunk size: the maximum audio length the model's positional
            # embeddings can handle, i.e. T_samples_max = max_frames * hop_length.
            self.chunk_samples = self.max_frames * self.hop_length

            print(
                f"BestRQ2Encoder constraints: Min Samples={self.min_samples}, "
                f"Chunk Samples={self.chunk_samples}"
            )
        except Exception as e:
            print(f"Warning: Could not extract dynamic length constraints: {e}")
            print("Falling back to safe defaults (1s min, 10s chunk)")
            self.min_samples = 16000
            self.chunk_samples = 16000 * 10

    def _forward_chunk(self, audio_chunk: torch.Tensor) -> torch.Tensor:
        """Process a single time-chunk of audio through the encoder."""
        # Determine the target device from the spectrogram window (the safest
        # reference for the STFT buffers).
        try:
            target_device = self.module.spectrogram.mel_spec.spectrogram.window.device
        except AttributeError:
            if hasattr(self.module.spectrogram.mel_spec, "window"):
                target_device = self.module.spectrogram.mel_spec.window.device
            else:
                target_device = self.module.device

        if audio_chunk.device != target_device:
            audio_chunk = audio_chunk.to(target_device)

        # BestRQ2Module expects [B, C, T]
        if audio_chunk.ndim == 2:
            audio_chunk = audio_chunk.unsqueeze(1)  # [B, 1, T]

        # _process_audio returns (patches, grid_size)
        patches, grid_size = self.module._process_audio(audio_chunk)

        # Create a dummy mask (all False = keep all patches).
        B, N, D = patches.shape
        mask = torch.zeros((B, N), dtype=torch.bool, device=patches.device)

        # Run the encoder.
        encoder_out = self.module.compute_encoder(patches, mask, grid_size)
        return encoder_out

    def forward(
        self, audio: torch.Tensor, audio_attention_mask=None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # audio: [B, T]
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)

        B, T = audio.shape

        # 1. Handle short audio (whole batch): pad up to the minimum length.
        if T < self.min_samples:
            pad_amt = self.min_samples - T
            audio = torch.nn.functional.pad(audio, (0, pad_amt))
            T = self.min_samples  # update T

        # 2. Sequential chunking.
        if T <= self.chunk_samples:
            # Single-chunk processing
            return self._forward_chunk(audio), None
        else:
            # Split into chunks of at most chunk_samples.
            chunks = torch.split(audio, self.chunk_samples, dim=1)
            outputs = []
            for chunk in chunks:
                # Handle a potentially short last chunk.
                chunk_len = chunk.shape[1]
                if chunk_len < self.min_samples:
                    pad_amt = self.min_samples - chunk_len
                    chunk = torch.nn.functional.pad(chunk, (0, pad_amt))

                out_chunk = self._forward_chunk(chunk)
                # If the last chunk was padded just to reach min_samples, we
                # cannot trim the result below patch granularity: one patch
                # covers min_samples, so a sub-patch remainder still yields
                # exactly one patch, which we keep as-is.
                outputs.append(out_chunk)

            # Concatenate along the sequence dimension (dim=1).
            final_output = torch.cat(outputs, dim=1)
            return final_output, None
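# Worked example of the length constraints above, under assumed (not
# config-verified) hyperparameters: with hop_length=160 (10 ms hops at
# 16 kHz), patch_size_time=4, and max_frames=1024,
#
#     min_samples   = 4 * 160    = 640      (one patch column ~= 40 ms)
#     chunk_samples = 1024 * 160 = 163840   (~10.24 s per chunk)
#
# so a 30 s clip at 16 kHz (480000 samples) is split by forward() into
# ceil(480000 / 163840) = 3 chunks whose encoder outputs are concatenated
# along the sequence dimension.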
if __name__ == "__main__":
    try:
        mdl = BestRQ2Encoder()
        print("Model initialized successfully")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        mdl.module.to(device)

        x = torch.randn(1, 160000).to(device)  # 10 s of noise at 16 kHz
        y, _ = mdl(x)
        print(f"Output shape: {y.shape}")
    except Exception as e:
        print(f"Error testing model: {e}")
        import traceback

        traceback.print_exc()
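    # Additional smoke test for the long-audio chunking path (a hedged sketch:
    # the exact token count depends on the hop length and patch sizes read
    # from config.yaml, so only the batch dimension is checked).
    if "mdl" in dir():
        try:
            # ~35 s of noise: long enough to force chunking under a typical ~10 s limit
            long_x = torch.randn(1, 16000 * 35).to(device)
            long_y, _ = mdl(long_x)
            print(f"Chunked output shape: {long_y.shape}")
            assert long_y.shape[0] == 1
        except Exception as e:
            print(f"Error testing chunked forward: {e}")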