| import glob |
| import os |
| import sys |
|
|
| import torch |
| import torch.nn as nn |
| from omegaconf import OmegaConf |
| from safetensors.torch import load_file |
|
|
| |
| |
| |
# Candidate locations of the local "audio-embeddings" checkout; the first
# existing directory wins.
POSSIBLE_PATHS = [
    os.path.abspath(os.path.join(os.path.dirname(__file__), "audio-embeddings")),
]

# Resolve to the first existing candidate, or None when nothing is found.
AUDIO_EMBEDDINGS_PATH = next((p for p in POSSIBLE_PATHS if os.path.exists(p)), None)

if AUDIO_EMBEDDINGS_PATH:
    # Make the checkout importable exactly once.
    if AUDIO_EMBEDDINGS_PATH not in sys.path:
        sys.path.append(AUDIO_EMBEDDINGS_PATH)
        print(f"Added {AUDIO_EMBEDDINGS_PATH} to sys.path")
else:
    print(
        "Warning: audio-embeddings path not found. Imports may fail if not installed in environment."
    )

# Fail fast with an actionable message if the project module is unreachable.
try:
    from src.models.best_rq2_module import BestRQ2Module
except ImportError as e:
    raise ImportError(
        f"Could not import src.models.best_rq2_module. Ensure audio-embeddings is correctly located or installed. Error: {e}"
    )
|
|
|
|
class BestRQ2Encoder(nn.Module):
    """Frozen wrapper around a pretrained BestRQ2Module for feature extraction.

    Loads an OmegaConf YAML config and a safetensors checkpoint (with a
    ``torch.load`` fallback), then exposes a chunked ``forward`` that maps raw
    mono waveforms to encoder embeddings.
    """

    def __init__(self, checkpoint_path=None, model_config_path=None, **kwargs):
        """Build the encoder and load pretrained weights.

        Args:
            checkpoint_path: Path to the model weights. Defaults to
                ``BEST-RQ-2.safetensors`` next to this file.
            model_config_path: Path to the OmegaConf YAML config. Defaults to
                ``config.yaml`` next to this file.
            **kwargs: Ignored; accepted for interface compatibility.

        Raises:
            FileNotFoundError: If the config or checkpoint file is missing.
        """
        super().__init__()

        # Bug fix: caller-supplied paths were previously overwritten
        # unconditionally with the hard-coded defaults; now they are only
        # defaulted when omitted (or empty).
        base_path = os.path.dirname(__file__)
        if not model_config_path:
            model_config_path = os.path.join(base_path, "config.yaml")
        if not checkpoint_path:
            checkpoint_path = os.path.join(base_path, "BEST-RQ-2.safetensors")

        if not os.path.exists(model_config_path):
            raise FileNotFoundError(f"Config not found at {model_config_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        print(f"Loading BestRQ2 config from {model_config_path}")
        cfg = OmegaConf.load(model_config_path)

        print(f"Loading BestRQ2 checkpoint from {checkpoint_path}")

        model_cfg = cfg.model
        net_cfg = model_cfg.net

        # Instantiate the training module without optimizer/criterion — only
        # inference is needed here. Defaults mirror the training config keys.
        self.module = BestRQ2Module(
            optimizer=None,
            net=net_cfg,
            warmup_pct=model_cfg.get("warmup_pct", 0.1),
            final_lr_ratio=model_cfg.get("final_lr_ratio", 0.001),
            spectrogram_adjustment_mode=model_cfg.get(
                "spectrogram_adjustment_mode", "pad"
            ),
            codebook_dim=model_cfg.get("codebook_dim", 16),
            vocab_size=model_cfg.get("vocab_size", 8192),
            criterion=None,
        )

        # Prefer safetensors; fall back to torch.load for legacy checkpoints
        # that wrap the weights in a "state_dict" key.
        try:
            state_dict = load_file(checkpoint_path)
        except Exception as e:
            print(f"Error loading safetensors: {e}. Trying torch.load...")
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]

        # strict=False: the checkpoint may omit training-only parts and the
        # module may define extras; surface both counts for visibility.
        missing, unexpected = self.module.load_state_dict(state_dict, strict=False)
        if missing:
            print(f"Warning: {len(missing)} keys missing during loading.")
        if unexpected:
            print(f"Warning: {len(unexpected)} keys unexpected during loading.")

        self.module.eval()
        self.output_dim = net_cfg.encoder.embed_dim

        # Derive chunking constraints from the loaded module when possible.
        try:
            self.sample_rate = self.module.spectrogram.mel_spec.sample_rate
            self.hop_length = self.module.spectrogram.mel_spec.hop_length

            # patch_size/img_size index 1 is the time axis — presumably
            # (freq, time) ordering; TODO confirm against the module.
            self.patch_size_time = self.module.patch_embed.patch_size[1]
            self.max_frames = self.module.patch_embed.img_size[1]

            # Shortest audio that yields at least one time patch.
            self.min_samples = self.patch_size_time * self.hop_length
            # Longest audio the patch embedding accepts in one pass.
            self.chunk_samples = self.max_frames * self.hop_length

            print(
                f"BestRQ2Encoder constraints: Min Samples={self.min_samples}, Chunk Samples={self.chunk_samples}"
            )
        except Exception as e:
            print(f"Warning: Could not extract dynamic length constraints: {e}")
            print("Falling back to safe defaults (1s min, 10s chunk)")
            self.min_samples = 16000
            self.chunk_samples = 16000 * 10

    def _forward_chunk(self, audio_chunk: torch.Tensor) -> torch.Tensor:
        """Helper to process a single time-chunk of audio (B, T) -> embeddings."""
        # Locate the device of the spectrogram's STFT window; the attribute
        # layout differs across module versions, hence the nested fallbacks.
        try:
            target_device = self.module.spectrogram.mel_spec.spectrogram.window.device
        except AttributeError:
            if hasattr(self.module.spectrogram.mel_spec, "window"):
                target_device = self.module.spectrogram.mel_spec.window.device
            else:
                target_device = self.module.device

        if audio_chunk.device != target_device:
            audio_chunk = audio_chunk.to(target_device)

        # _process_audio expects a channel dimension: (B, 1, T).
        if audio_chunk.ndim == 2:
            audio_chunk = audio_chunk.unsqueeze(1)

        patches, grid_size = self.module._process_audio(audio_chunk)

        # Inference uses no masking: an all-False mask over the N patches.
        B, N, D = patches.shape
        mask = torch.zeros((B, N), dtype=torch.bool, device=patches.device)

        encoder_out = self.module.compute_encoder(patches, mask, grid_size)
        return encoder_out

    def forward(
        self, audio: torch.Tensor, audio_attention_mask=None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Encode a waveform, splitting into fixed-size chunks when needed.

        Args:
            audio: Waveform of shape (T,) or (B, T).
            audio_attention_mask: Unused; kept for interface compatibility.

        Returns:
            Tuple of (embeddings, None). Embeddings are the encoder output,
            chunk outputs concatenated along dim 1 for long inputs.
        """
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)

        B, T = audio.shape

        # Right-pad very short inputs up to the minimum patchable length.
        if T < self.min_samples:
            pad_amt = self.min_samples - T
            audio = torch.nn.functional.pad(audio, (0, pad_amt))
            T = self.min_samples

        if T <= self.chunk_samples:
            # Fits in a single pass.
            return self._forward_chunk(audio), None

        # Long input: encode fixed-size chunks and concatenate along time.
        # NOTE(review): padding added to a short final chunk is not trimmed
        # from the output, so trailing frames may encode silence — confirm
        # downstream consumers tolerate this.
        chunks = torch.split(audio, self.chunk_samples, dim=1)
        outputs = []
        for chunk in chunks:
            chunk_len = chunk.shape[1]
            if chunk_len < self.min_samples:
                pad_amt = self.min_samples - chunk_len
                chunk = torch.nn.functional.pad(chunk, (0, pad_amt))
            outputs.append(self._forward_chunk(chunk))

        final_output = torch.cat(outputs, dim=1)
        return final_output, None
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the encoder from the default local files and push
    # 160k random samples through it, reporting the output shape.
    try:
        encoder = BestRQ2Encoder()
        print("Model initialized successfully")
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        encoder.module.to(run_device)
        dummy_audio = torch.randn(1, 160000).to(run_device)
        features, _ = encoder(dummy_audio)
        print(f"Output shape: {features.shape}")
    except Exception as e:
        print(f"Error testing model: {e}")
        import traceback

        traceback.print_exc()
|
|