Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| NVIDIA Nemo Codec Test - Gradio App | |
| Equivalent to snac_test.py but for the NVIDIA Nemo codec used in Kani TTS based models. | |
| Allows testing encode/decode cycles with the nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps model. | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import torchaudio.transforms as T | |
| import numpy as np | |
| import traceback | |
| import time | |
# Nemo is a hard requirement for this app: when the import fails, the error is
# printed and re-raised with an installation hint so start-up stops here.
try:
    from nemo.collections.tts.models import AudioCodecModel
    from nemo.utils.nemo_logging import Logger

    # Remove Nemo's stream handlers to suppress its logging output.
    nemo_logger = Logger()
    nemo_logger.remove_stream_handlers()
    print("Nemo modules imported successfully.")
except ImportError as import_error:
    print(f"Error importing Nemo: {import_error}")
    raise ImportError(
        "Could not import Nemo. Make sure 'nemo_toolkit[tts]' is installed correctly."
    ) from import_error
# --- Configuration ---
# Checkpoint identifier passed to AudioCodecModel.from_pretrained.
MODEL_NAME = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
# Sample rate (Hz) that audio is resampled to before encoding (22 kHz codec).
TARGET_SR = 22050
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")
# --- Load Model (Load once globally) ---
# `nemo_codec` stays None unless EVERY loading step succeeds. The model is
# built under a temporary name and only published at the end, so a failure in
# `.to(DEVICE)` or `.eval()` cannot leave a half-initialised model behind that
# would slip past the `nemo_codec is None` checks used elsewhere in this file.
nemo_codec = None
try:
    print(f"Loading Nemo codec model: {MODEL_NAME}...")
    start_time = time.time()
    _loaded_codec = AudioCodecModel.from_pretrained(MODEL_NAME)
    _loaded_codec = _loaded_codec.to(DEVICE)
    _loaded_codec.eval()  # Set model to evaluation mode
    end_time = time.time()
    nemo_codec = _loaded_codec
    print(f"Nemo codec loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    # Top-level boundary: report the failure and keep the module importable;
    # callers detect the failed load through `nemo_codec is None`.
    print(f"FATAL: Error loading Nemo codec: {e}")
    print(traceback.format_exc())
# --- Main Processing Function ---
def process_audio(audio_filepath):
    """
    Run one encode/decode round trip of an audio file through the Nemo codec.

    The file is loaded with torchaudio, reduced to its first channel,
    resampled to TARGET_SR when its sample rate differs, encoded to tokens and
    decoded back to a waveform using the module-level ``nemo_codec``.

    Args:
        audio_filepath: Path of the audio file to process, or None when no
            file was provided.

    Returns:
        A 4-tuple ``(original, resampled, reconstructed, log_text)``.
        The first three items are ``(sample_rate, numpy_array)`` pairs for
        playback; all three are None when the codec is not loaded, when no
        file was given, or when processing failed. ``log_text`` is the
        newline-joined processing log and includes the traceback on failure.

    Exceptions raised while processing are caught and reported through
    ``log_text`` instead of being propagated.
    """
    if nemo_codec is None:
        return None, None, None, "Error: Nemo codec could not be loaded. Cannot process audio."
    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing with Nemo Codec ---"]
    try:
        # 1. Load Audio
        # perf_counter() is monotonic, so the reported durations cannot be
        # distorted by wall-clock adjustments.
        logs.append(f"Loading audio file: {audio_filepath}")
        load_start = time.perf_counter()
        original_waveform, original_sr = torchaudio.load(audio_filepath)
        load_end = time.perf_counter()
        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")

        # Reject empty input here with a clear message instead of letting the
        # codec fail later with an opaque error.
        if original_waveform.numel() == 0:
            raise ValueError("Input audio contains no samples.")

        # Ensure float32
        original_waveform = original_waveform.to(dtype=torch.float32)

        # Handle multi-channel audio: Use the first channel
        if original_waveform.shape[0] > 1:
            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
            original_waveform = original_waveform[0:1, :]  # Keep channel dim for consistency

        # --- Prepare Original for Playback ---
        # reshape(-1) instead of squeeze(): squeeze() drops every singleton
        # dimension, so a one-sample clip would turn into a 0-d array and
        # break the len() calls in the metrics step.
        original_audio_playback = (original_sr, original_waveform.reshape(-1).numpy())
        logs.append("Prepared original audio for playback.")

        # 2. Resample if necessary
        resample_start = time.perf_counter()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append("Waveform is already at the target sample rate (22kHz).")
            waveform_to_encode = original_waveform
        resample_end = time.perf_counter()
        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")

        # --- Prepare Resampled for Playback ---
        resampled_np = waveform_to_encode.reshape(-1).numpy()
        resampled_audio_playback = (TARGET_SR, resampled_np)
        logs.append("Prepared resampled audio for playback.")

        # 3. Prepare for Nemo Encoding
        # Nemo expects [batch, samples] format. Only one channel is left at
        # this point, so a single reshape yields a batch of one.
        waveform_batch = waveform_to_encode.reshape(1, -1).to(DEVICE)
        # Calculate audio length for Nemo
        audio_len = torch.tensor([waveform_batch.shape[-1]], dtype=torch.int64, device=DEVICE)
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Audio length: {audio_len.item()}, Device: {DEVICE}")

        # 4. Encode Audio using Nemo
        logs.append("Encoding audio with Nemo codec...")
        encode_start = time.perf_counter()
        with torch.inference_mode():
            encoded_tokens, tokens_len = nemo_codec.encode(audio=waveform_batch, audio_len=audio_len)
        encode_end = time.perf_counter()
        if encoded_tokens is None:
            log_msg = "Encoding failed: encoded_tokens is None"
            logs.append(log_msg)
            raise ValueError(log_msg)
        logs.append(f"Encoding complete. Time: {encode_end - encode_start:.2f}s")
        logs.append(f"Encoded tokens shape: {encoded_tokens.shape}, tokens_len: {tokens_len}")
        logs.append(f"Encoded tokens device: {encoded_tokens.device}")

        # Log some statistics about the tokens
        if encoded_tokens.dim() >= 2:
            logs.append(f"Number of codebooks: {encoded_tokens.shape[1] if encoded_tokens.dim() >= 3 else 'N/A'}")
            logs.append(f"Sequence length: {encoded_tokens.shape[-1]}")
            # min()/max() raise on an empty tensor; a statistics line must not
            # be what aborts the run.
            if encoded_tokens.numel() > 0:
                logs.append(f"Token range: [{encoded_tokens.min().item():.0f}, {encoded_tokens.max().item():.0f}]")

        # 5. Decode the Tokens using Nemo
        logs.append("Decoding the generated tokens with Nemo codec...")
        decode_start = time.perf_counter()
        with torch.inference_mode():
            reconstructed_waveform, _ = nemo_codec.decode(tokens=encoded_tokens, tokens_len=tokens_len)
        decode_end = time.perf_counter()
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")

        # 6. Prepare Reconstructed Audio for Playback
        # Output should be [batch, samples]. Move to CPU, flatten the batch of
        # one to 1-D, convert to NumPy.
        reconstructed_audio_np = reconstructed_waveform.cpu().reshape(-1).numpy()
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        # 7. Calculate quality metrics
        # Compare only the overlapping part: input and output lengths can
        # differ (common with codecs).
        original_for_comparison = resampled_np
        original_len = len(original_for_comparison)
        reconstructed_len = len(reconstructed_audio_np)
        min_len = min(original_len, reconstructed_len)
        # Simple MSE calculation
        mse = np.mean((original_for_comparison[:min_len] - reconstructed_audio_np[:min_len]) ** 2)
        if original_len != reconstructed_len:
            logs.append(f"Audio length difference: Original {original_len}, Reconstructed {reconstructed_len}")
            logs.append(f"MSE (first {min_len} samples): {mse:.6f}")
        else:
            logs.append(f"MSE: {mse:.6f}")

        logs.append("\n--- Audio Processing Completed Successfully ---")
        # Ratio of input samples to emitted tokens; max(..., 1) avoids a
        # division by zero when no tokens were produced.
        logs.append(f"Compression ratio: ~{original_len / max(encoded_tokens.numel(), 1):.1f}:1")
        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)
    except Exception as e:
        # UI boundary: report the failure in the log output instead of raising.
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        return None, None, None, "\n".join(logs)
# --- Gradio Interface ---
# Markdown text handed to gr.Interface as its `description` argument.
# This is user-facing runtime text; edit it only to change what the page shows.
DESCRIPTION = """
This app demonstrates the **NVIDIA Nemo Codec** model (`nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps`) used in Kani TTS.
**How it works:**
1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio will be automatically resampled to 22kHz if needed.
3. The 22kHz audio is encoded into discrete tokens by the Nemo codec.
4. These tokens are then decoded back into audio by the Nemo codec.
5. You can listen to the original, the 22kHz version (if resampled), and the final reconstructed audio.
**Technical details:**
- Sample rate: 22kHz
- Compression: ~0.6kbps
- Frame rate: 12.5fps
- 4 codebook levels per frame
**Note:** Processing happens locally. Larger files will take longer. If the input is stereo, only the first channel is processed.
"""
# Input: the uploaded file is handed to process_audio as a file path.
_audio_input = gr.Audio(type="filepath", label="Upload Audio File")

# Outputs, in the same order as the 4-tuple returned by process_audio.
_result_outputs = [
    gr.Audio(label="Original Audio"),
    gr.Audio(label="Resampled Audio (22kHz Input to Nemo)"),
    gr.Audio(label="Reconstructed Audio (Output from Nemo Codec)"),
    gr.Textbox(label="Log Output", lines=20),
]

# No bundled example files yet, so the examples list is empty.
# later I might add some samples
# ["examples/example1.wav"],
# ["examples/example2.wav"],
_example_files = []

iface = gr.Interface(
    fn=process_audio,
    inputs=_audio_input,
    outputs=_result_outputs,
    title="NVIDIA Nemo Codec Demo (22kHz)",
    description=DESCRIPTION,
    examples=_example_files,
    cache_examples=False,
)
if __name__ == "__main__":
    # Only start the web UI when the codec was loaded; otherwise every
    # request would just return the "could not be loaded" error.
    if nemo_codec is not None:
        # One print call with a newline separator emits the same four lines
        # as four separate prints.
        print(
            "Launching Gradio Interface...",
            f"Model: {MODEL_NAME}",
            f"Target sample rate: {TARGET_SR} Hz",
            f"Device: {DEVICE}",
            sep="\n",
        )
        iface.launch(share=True)
    else:
        print("Cannot launch Gradio interface because Nemo codec failed to load.")