File size: 9,957 Bytes
c471c87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#!/usr/bin/env python3
"""
NVIDIA Nemo Codec Test - Gradio App
Equivalent to snac_test.py but for the NVIDIA Nemo codec used in Kani TTS based models.
Allows testing encode/decode cycles with the nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps model.
"""

import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import traceback
import time

# Nemo is a hard requirement of this app: abort at import time with an
# installation hint when it cannot be imported.
try:
    from nemo.collections.tts.models import AudioCodecModel
    from nemo.utils.nemo_logging import Logger

    # Quiet Nemo's console output by removing its stream handlers.
    nemo_logger = Logger()
    nemo_logger.remove_stream_handlers()
    print("Nemo modules imported successfully.")
except ImportError as import_error:
    print(f"Error importing Nemo: {import_error}")
    raise ImportError(
        "Could not import Nemo. Make sure 'nemo_toolkit[tts]' is installed correctly."
    ) from import_error

# --- Configuration ---
# Checkpoint name handed to AudioCodecModel.from_pretrained() when the model is loaded.
MODEL_NAME = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
# Sample rate (Hz) that every input is resampled to before encoding; the Nemo codec operates at 22 kHz.
TARGET_SR = 22050
# Run on the GPU whenever PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# --- Load Model (Load once globally) ---
# The codec is loaded once at import time. ``nemo_codec`` is only left non-None
# when loading, the device transfer and eval() all succeed; process_audio() and
# the __main__ launcher both treat None as "model unavailable".
nemo_codec = None
try:
    print(f"Loading Nemo codec model: {MODEL_NAME}...")
    start_time = time.perf_counter()
    nemo_codec = AudioCodecModel.from_pretrained(MODEL_NAME)
    nemo_codec = nemo_codec.to(DEVICE)
    nemo_codec.eval()  # Set model to evaluation mode (inference only)
    end_time = time.perf_counter()
    print(f"Nemo codec loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    # Discard any partially initialised model (e.g. when .to(DEVICE) or eval()
    # raised after from_pretrained() succeeded). Otherwise the None checks would
    # pass, the app would launch, and every request would then fail.
    nemo_codec = None
    print(f"FATAL: Error loading Nemo codec: {e}")
    print(traceback.format_exc())

# --- Main Processing Function ---
def _synchronize(device):
    """Wait for all queued CUDA work on ``device`` to finish; no-op for non-CUDA devices.

    CUDA kernels are launched asynchronously, so a wall-clock timing taken right
    after a model call returns would otherwise measure only the kernel launches.
    """
    dev = torch.device(device)
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


def _to_1d_numpy(waveform):
    """Return a single-channel / batch-size-1 waveform tensor as a 1-D NumPy array on the CPU.

    ``reshape(-1)`` is used instead of ``squeeze()`` so that a one-sample clip
    still yields a 1-D array (``squeeze()`` would give a 0-d scalar, on which
    ``len()`` fails).
    """
    return waveform.detach().cpu().reshape(-1).numpy()


def process_audio(audio_filepath):
    """Round-trip an audio file through the Nemo codec and report the results.

    Steps: load the file, keep only its first channel, resample to TARGET_SR,
    encode it to discrete tokens with ``nemo_codec.encode``, decode the tokens
    with ``nemo_codec.decode``, and compare the reconstruction with the codec input.

    Args:
        audio_filepath: Path of the uploaded audio file (the Gradio input uses
            ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        A 4-tuple ``(original, resampled, reconstructed, log_text)``. The first three
        are ``(sample_rate, 1-D numpy array)`` pairs for ``gr.Audio``, or None when
        processing did not complete. ``log_text`` is the newline-joined processing log.
        Any exception raised during processing is caught and reported in ``log_text``
        rather than propagated.
    """
    if nemo_codec is None:
        return None, None, None, "Error: Nemo codec could not be loaded. Cannot process audio."

    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing with Nemo Codec ---"]
    try:
        # 1. Load Audio
        logs.append(f"Loading audio file: {audio_filepath}")
        load_start = time.perf_counter()
        original_waveform, original_sr = torchaudio.load(audio_filepath)
        load_time = time.perf_counter() - load_start
        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_time:.2f}s")

        # Ensure float32
        original_waveform = original_waveform.to(dtype=torch.float32)

        # Handle multi-channel audio: dim 0 is channels; keep only the first one.
        if original_waveform.shape[0] > 1:
            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
            original_waveform = original_waveform[0:1, :]  # Keep channel dim: [1, samples]

        # Fail early with a clear message instead of an opaque error inside the codec.
        if original_waveform.shape[-1] == 0:
            raise ValueError("Input audio contains no samples.")

        # --- Prepare Original for Playback ---
        original_audio_playback = (original_sr, _to_1d_numpy(original_waveform))
        logs.append("Prepared original audio for playback.")

        # 2. Resample if necessary
        resample_start = time.perf_counter()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append("Waveform is already at the target sample rate (22kHz).")
            waveform_to_encode = original_waveform
        logs.append(f"Resampling time: {time.perf_counter() - resample_start:.2f}s")

        # 1-D copy of the codec input, reused for playback and for the error metric.
        codec_input_np = _to_1d_numpy(waveform_to_encode)
        resampled_audio_playback = (TARGET_SR, codec_input_np)
        logs.append("Prepared resampled audio for playback.")

        # 3. Prepare for Nemo encoding: [batch, samples] audio plus a per-item length tensor.
        waveform_batch = waveform_to_encode.reshape(1, -1).to(DEVICE)
        audio_len = torch.tensor([waveform_batch.shape[-1]], dtype=torch.int64, device=DEVICE)
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Audio length: {audio_len.item()}, Device: {DEVICE}")

        # 4. Encode Audio using Nemo
        logs.append("Encoding audio with Nemo codec...")
        encode_start = time.perf_counter()
        with torch.inference_mode():
            encoded_tokens, tokens_len = nemo_codec.encode(audio=waveform_batch, audio_len=audio_len)
        _synchronize(DEVICE)  # include asynchronous GPU work in the measured time
        encode_time = time.perf_counter() - encode_start

        if encoded_tokens is None:
            log_msg = "Encoding failed: encoded_tokens is None"
            logs.append(log_msg)
            raise ValueError(log_msg)

        logs.append(f"Encoding complete. Time: {encode_time:.2f}s")
        logs.append(f"Encoded tokens shape: {encoded_tokens.shape}, tokens_len: {tokens_len}")
        logs.append(f"Encoded tokens device: {encoded_tokens.device}")

        # Token statistics. For 3-D outputs dim 1 is reported as the codebook axis
        # (assumed [batch, codebooks, frames] -- TODO confirm against the Nemo docs).
        if encoded_tokens.dim() >= 2:
            num_codebooks = encoded_tokens.shape[1] if encoded_tokens.dim() >= 3 else "N/A"
            logs.append(f"Number of codebooks: {num_codebooks}")
            logs.append(f"Sequence length: {encoded_tokens.shape[-1]}")
            if encoded_tokens.numel() > 0:  # min()/max() raise on empty tensors
                logs.append(f"Token range: [{encoded_tokens.min().item():.0f}, {encoded_tokens.max().item():.0f}]")

        # 5. Decode the Tokens using Nemo
        logs.append("Decoding the generated tokens with Nemo codec...")
        decode_start = time.perf_counter()
        with torch.inference_mode():
            reconstructed_waveform, _ = nemo_codec.decode(tokens=encoded_tokens, tokens_len=tokens_len)
        _synchronize(DEVICE)  # include asynchronous GPU work in the measured time
        decode_time = time.perf_counter() - decode_start
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_time:.2f}s")

        # 6. Prepare Reconstructed Audio for Playback (batch size is 1, so flatten to 1-D).
        reconstructed_audio_np = _to_1d_numpy(reconstructed_waveform)
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        # 7. Quality metrics. Length differences are common with codecs, so only
        #    the overlapping prefix is compared (in float64 for accuracy).
        n_input = len(codec_input_np)
        n_recon = len(reconstructed_audio_np)
        n_common = min(n_input, n_recon)
        if n_common > 0:
            diff = codec_input_np[:n_common].astype(np.float64) - reconstructed_audio_np[:n_common].astype(np.float64)
            mse = float(np.mean(np.square(diff)))
        else:
            mse = float("nan")  # nothing to compare; avoids numpy's empty-mean warning
        if n_input != n_recon:
            logs.append(f"Audio length difference: Original {n_input}, Reconstructed {n_recon}")
            logs.append(f"MSE (first {n_common} samples): {mse:.6f}")
        else:
            logs.append(f"MSE: {mse:.6f}")

        # Ratio of input samples to emitted tokens. It counts items, not bits,
        # so it is deliberately not labelled a compression ratio.
        n_tokens = encoded_tokens.numel()
        logs.append(f"Samples per token: ~{n_input / max(n_tokens, 1):.1f}")
        logs.append("\n--- Audio Processing Completed Successfully ---")

        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)

    except Exception as e:
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        return None, None, None, "\n".join(logs)

# --- Gradio Interface ---
# Markdown shown above the demo; rendered by Gradio as the interface description.
DESCRIPTION = """
This app demonstrates the **NVIDIA Nemo Codec** model (`nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps`) used in Kani TTS.

**How it works:**
1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio will be automatically resampled to 22kHz if needed.
3. The 22kHz audio is encoded into discrete tokens by the Nemo codec.
4. These tokens are then decoded back into audio by the Nemo codec.
5. You can listen to the original, the 22kHz version (if resampled), and the final reconstructed audio.

**Technical details:**
- Sample rate: 22kHz
- Compression: ~0.6kbps
- Frame rate: 12.5fps
- 4 codebook levels per frame

**Note:** Processing happens locally. Larger files will take longer. If the input is stereo, only the first channel is processed.
"""

# One file-path audio input; the four outputs line up, in order, with the
# 4-tuple returned by process_audio().
iface = gr.Interface(
    title="NVIDIA Nemo Codec Demo (22kHz)",
    description=DESCRIPTION,
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=[
        gr.Audio(label="Original Audio"),
        gr.Audio(label="Resampled Audio (22kHz Input to Nemo)"),
        gr.Audio(label="Reconstructed Audio (Output from Nemo Codec)"),
        gr.Textbox(label="Log Output", lines=20),
    ],
    # TODO: no sample clips are bundled yet; add entries such as ["examples/example1.wav"].
    examples=[],
    cache_examples=False,
)

if __name__ == "__main__":
    if nemo_codec is None:
        print("Cannot launch Gradio interface because Nemo codec failed to load.")
        # Exit non-zero so shells and CI see the failure instead of a silent success.
        raise SystemExit(1)

    print("Launching Gradio Interface...")
    print(f"Model: {MODEL_NAME}")
    print(f"Target sample rate: {TARGET_SR} Hz")
    print(f"Device: {DEVICE}")
    # share=True asks Gradio to also create a public share link for this app,
    # in addition to the local server.
    iface.launch(share=True)