"""Forced alignment for word-level timestamps using Wav2Vec2.""" |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
START_OFFSET = 0.04 |
|
|
END_OFFSET = -0.04 |
|
|
|
|
|
|
|
|


def _get_device() -> str:
    """Get best available device for non-transformers models."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class ForcedAligner:
    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.

    Uses a Viterbi trellis to find the optimal alignment path.
    """

    _bundle = None
    _model = None
    _labels = None
    _dictionary = None

    @classmethod
    def get_instance(cls, device: str = "cuda"):
        """Get or create the forced alignment model (singleton).

        Args:
            device: Device to run the model on ("cuda", "mps", or "cpu")

        Returns:
            Tuple of (model, labels, dictionary)
        """
        if cls._model is None:
            import torchaudio

            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            cls._model = cls._bundle.get_model().to(device)
            cls._model.eval()
            cls._labels = cls._bundle.get_labels()
            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
        return cls._model, cls._labels, cls._dictionary
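
    # For WAV2VEC2_ASR_BASE_960H, get_labels() places the CTC blank ("-") at
    # index 0 and uses "|" as the word boundary; the trellis code below assumes
    # blank_id == 0 and align() treats "|" as the word separator.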

    @staticmethod
    def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
        """Build the trellis for forced alignment with a Viterbi (max-product) recursion.

        trellis[t, j] is the log probability of the best path that aligns the
        first j tokens to the first t frames.

        Args:
            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token (default 0)

        Returns:
            Trellis matrix of shape (num_frames + 1, num_tokens + 1)
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
        trellis[0, 0] = 0

        if num_tokens > 1:
            trellis[-num_tokens + 1:, 0] = float("inf")

        for t in range(num_frames):
            for j in range(num_tokens + 1):
                # Stay on the current token: emit a blank at this frame.
                stay = trellis[t, j] + emission[t, blank_id]
                # Advance to the next token: emit tokens[j - 1] at this frame.
                move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")
                trellis[t + 1, j] = max(stay, move)

        return trellis
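
    # Worked sketch of the recurrence above: with tokens [A, B] and per-frame
    # log-probabilities e[t], the update for column 1 at frame t is
    #   trellis[t + 1, 1] = max(trellis[t, 1] + e[t][blank],  # hold (blank frame)
    #                           trellis[t, 0] + e[t][A])      # emit A at frame t
    # so trellis[num_frames, num_tokens] ends up scoring the best complete
    # alignment of the whole token sequence.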

    @staticmethod
    def _backtrack(
        trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
    ) -> list[tuple[int, float, float, float]]:
        """Backtrack through the trellis to find the optimal forced monotonic alignment.

        Guarantees:
        - All tokens are emitted exactly once
        - Strictly monotonic: each token's frames come after the previous token's
        - No frame skipping or token teleporting

        Returns a list of (token_id, start_frame, end_frame, peak_frame) for each token.
        The peak_frame is the frame with the highest emission probability for that token.
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        if num_tokens == 0:
            return []

        # If the final trellis cell is unreachable (e.g. fewer frames than tokens),
        # fall back to spreading the tokens evenly across the available frames.
        if trellis[num_frames, num_tokens] == -float("inf"):
            frames_per_token = num_frames / num_tokens
            return [
                (tokens[i], i * frames_per_token, (i + 1) * frames_per_token, (i + 0.5) * frames_per_token)
                for i in range(num_tokens)
            ]

        # Frames assigned to each token, as (frame_index, emission_probability) pairs.
        token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]

        t = num_frames
        j = num_tokens

        while t > 0 and j > 0:
            # Compare staying on the current token (blank at frame t - 1) against
            # having emitted token j - 1 at frame t - 1.
            stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
            move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            if move_score >= stay_score:
                # Token j - 1 was emitted at frame t - 1; record it with its probability.
                emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
                token_frames[j - 1].insert(0, (t - 1, emit_prob))
                j -= 1

            t -= 1

        # Any tokens left unassigned when we reach frame 0 are pinned to frame 0.
        while j > 0:
            token_frames[j - 1].insert(0, (0, 0.0))
            j -= 1

        # Convert per-token frame lists into (token_id, start, end, peak) spans.
        token_spans: list[tuple[int, float, float, float]] = []
        for token_idx, frames_with_scores in enumerate(token_frames):
            if not frames_with_scores:
                # Defensive fallback: give the token a one-frame span starting
                # where the previous token ended.
                if token_spans:
                    prev_end = token_spans[-1][2]
                    frames_with_scores = [(int(prev_end), 0.0)]
                else:
                    frames_with_scores = [(0, 0.0)]

            token_id = tokens[token_idx]
            frames = [f for f, _ in frames_with_scores]
            start_frame = float(min(frames))
            end_frame = float(max(frames)) + 1.0

            # Peak frame: the assigned frame with the highest emission probability.
            peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])

            token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))

        return token_spans
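
    # Illustrative example of _backtrack output (numbers made up): for the token
    # sequence "HI" over six frames it might return
    #   [(h_id, 1.0, 2.0, 1.0),   # 'H' assigned frame 1
    #    (i_id, 3.0, 4.0, 3.0)]   # 'I' assigned frame 3
    # where h_id / i_id stand for the label indices of 'H' and 'I' in the bundle.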

    @classmethod
    def align(
        cls,
        audio: np.ndarray,
        text: str,
        sample_rate: int = 16000,
        _language: str = "eng",
        _batch_size: int = 16,
    ) -> list[dict]:
        """Align a transcript to audio and return word-level timestamps.

        Uses a Viterbi trellis to find the optimal forced alignment.

        Args:
            audio: Audio waveform as numpy array
            text: Transcript text to align
            sample_rate: Audio sample rate (default 16000)
            _language: ISO-639-3 language code (default "eng" for English, unused)
            _batch_size: Batch size for alignment model (unused)

        Returns:
            List of dicts with 'word', 'start', 'end' keys
        """
        import torchaudio

        device = _get_device()
        model, _labels, dictionary = cls.get_instance(device)
        assert cls._bundle is not None and dictionary is not None

        # Normalize the input to a (1, num_samples) float tensor.
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy()).float()
        else:
            waveform = audio.clone().float()

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Resample to the bundle's expected sample rate if needed.
        if sample_rate != cls._bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, cls._bundle.sample_rate
            )

        waveform = waveform.to(device)

        # Frame-level log-probabilities over the character labels.
        with torch.inference_mode():
            emissions, _ = model(waveform)
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu()

        # The wav2vec2 labels are uppercase characters.
        transcript = text.upper()

        # Map characters to label indices; spaces become the "|" word separator
        # and characters missing from the label set (e.g. punctuation) are skipped.
        tokens = []
        for char in transcript:
            if char in dictionary:
                tokens.append(dictionary[char])
            elif char == " ":
                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))

        if not tokens:
            return []

        trellis = cls._get_trellis(emission, tokens, blank_id=0)
        alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)

        # The wav2vec2 base model emits one frame per 320 input samples (20 ms at 16 kHz).
        frame_duration = 320 / cls._bundle.sample_rate

        start_offset = START_OFFSET
        end_offset = END_OFFSET

        # Walk the aligned tokens; each "|" separator closes out the current word.
        words = text.split()
        word_timestamps = []
        first_char_peak = None
        last_char_peak = None
        word_idx = 0
        separator_id = dictionary.get("|", dictionary.get(" ", 0))

        for token_id, _start_frame, _end_frame, peak_frame in alignment_path:
            if token_id == separator_id:
                if (
                    first_char_peak is not None
                    and last_char_peak is not None
                    and word_idx < len(words)
                ):
                    start_time = max(0.0, first_char_peak * frame_duration - start_offset)
                    end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
                    word_timestamps.append(
                        {
                            "word": words[word_idx],
                            "start": start_time,
                            "end": end_time,
                        }
                    )
                    word_idx += 1
                first_char_peak = None
                last_char_peak = None
            else:
                if first_char_peak is None:
                    first_char_peak = peak_frame
                last_char_peak = peak_frame

        # Flush the final word (the transcript usually does not end with a separator).
        if (
            first_char_peak is not None
            and last_char_peak is not None
            and word_idx < len(words)
        ):
            start_time = max(0.0, first_char_peak * frame_duration - start_offset)
            end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
            word_timestamps.append(
                {
                    "word": words[word_idx],
                    "start": start_time,
                    "end": end_time,
                }
            )

        return word_timestamps
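

if __name__ == "__main__":
    # Minimal usage sketch. The sine tone below is only a stand-in for real
    # speech, so the timestamps it yields are meaningless; with an actual
    # 16 kHz mono recording and its transcript, each dict maps one word to
    # start/end times in seconds.
    sr = 16000
    dummy_audio = np.sin(2 * np.pi * 220.0 * np.arange(sr) / sr).astype(np.float32)
    for entry in ForcedAligner.align(dummy_audio, "hello world", sample_rate=sr):
        print(f"{entry['word']:>10s}  {entry['start']:6.2f}s - {entry['end']:6.2f}s")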