| import torch |
| import torch.nn as nn |
| import torchaudio |
| from typing import Optional |
|
|
|
|
class Spectrogram(nn.Module):
    """Log-mel spectrogram front end: ``MelSpectrogram`` followed by ``AmplitudeToDB``.

    Window and hop sizes can be given either in samples or in milliseconds.
    When a millisecond value is provided it takes precedence over the matching
    sample-count argument and is converted with
    ``int(sample_rate * ms / 1000)`` (truncating toward zero).

    Args:
        sample_rate (int): Sample rate of the audio.
        n_fft (int): Size of FFT.
        win_length (Optional[int]): Window length in samples. Defaults to n_fft.
        win_length_ms (Optional[float]): Window length in milliseconds.
            Overrides win_length if provided.
        hop_length (Optional[int]): Hop length in samples. Defaults to
            win_length // 2, using the resolved window length.
        hop_length_ms (Optional[float]): Hop length in milliseconds.
            Overrides hop_length if provided.
        n_mels (int): Number of mel filterbanks.
        f_min (float): Minimum frequency.
        f_max (Optional[float]): Maximum frequency.
        power (float): Exponent applied to the magnitude spectrogram.

    Raises:
        ValueError: If the resolved window length is not in ``(0, n_fft]`` or
            the resolved hop length is not positive.
    """

    def __init__(
        self,
        sample_rate: int = 32000,
        n_fft: int = 4096,
        win_length: Optional[int] = None,
        win_length_ms: Optional[float] = None,
        hop_length: Optional[int] = None,
        hop_length_ms: Optional[float] = None,
        n_mels: int = 128,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        power: float = 2.0,
    ):
        super().__init__()

        win_length = self._resolve_length(
            win_length, win_length_ms, sample_rate, n_fft
        )
        # The default hop is derived from the *resolved* window length, so a
        # window given in milliseconds also drives the default hop size.
        hop_length = self._resolve_length(
            hop_length, hop_length_ms, sample_rate, win_length // 2
        )

        # Fail at construction time with a clear message instead of letting
        # the underlying STFT reject these values on the first forward pass.
        if win_length <= 0 or win_length > n_fft:
            raise ValueError(
                f"win_length must satisfy 0 < win_length <= n_fft, "
                f"got win_length={win_length}, n_fft={n_fft}"
            )
        if hop_length <= 0:
            raise ValueError(f"hop_length must be positive, got {hop_length}")

        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            power=power,
            normalized=True,
        )
        # NOTE(review): AmplitudeToDB is built with its defaults; per the
        # torchaudio docs that means stype="power", which matches the default
        # power=2.0. Confirm this is intended when power != 2.0.
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    @staticmethod
    def _resolve_length(
        samples: Optional[int],
        millis: Optional[float],
        sample_rate: int,
        fallback: int,
    ) -> int:
        """Pick a length in samples: milliseconds first, then samples, then fallback."""
        if millis is not None:
            return int(sample_rate * millis / 1000)
        if samples is not None:
            return samples
        return fallback

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the log-mel spectrogram of a waveform.

        Args:
            x (torch.Tensor): Input waveform [B, C, T] or [B, T].

        Returns:
            torch.Tensor: Log-mel spectrogram in dB. Following torchaudio's
            documented ``(..., time) -> (..., n_mels, frames)`` convention,
            this is [B, C, F, T'] for [B, C, T] input and [B, F, T'] for
            [B, T] input, where T' is the number of STFT frames.
        """
        spec = self.mel_spec(x)
        spec = self.amplitude_to_db(spec)
        return spec
|
|