Alexander4127 committed on
Commit 8b43d4b · verified · 1 Parent(s): 34c3137

Upload folder using huggingface_hub

Files changed (4)
  1. config.json +67 -0
  2. modeling_gigaam.py +1426 -0
  3. pytorch_model.bin +3 -0
  4. tokenizer.model +3 -0
config.json ADDED
@@ -0,0 +1,67 @@
1
+ {
2
+ "model_type": "gigaam",
3
+ "auto_map": {
4
+ "AutoConfig": "modeling_gigaam.GigaAMConfig",
5
+ "AutoModel": "modeling_gigaam.GigaAMModel"
6
+ },
7
+ "cfg": {
8
+ "model": {
9
+ "cfg": {
10
+ "model_class": "rnnt",
11
+ "sample_rate": 16000,
12
+ "preprocessor": {
13
+ "_target_": "modeling_gigaam.FeatureExtractor",
14
+ "sample_rate": 16000,
15
+ "features": 64,
16
+ "win_length": 320,
17
+ "hop_length": 160,
18
+ "mel_scale": "htk",
19
+ "n_fft": 320,
20
+ "mel_norm": null,
21
+ "center": false
22
+ },
23
+ "encoder": {
24
+ "_target_": "modeling_gigaam.ConformerEncoder",
25
+ "feat_in": 64,
26
+ "n_layers": 16,
27
+ "d_model": 768,
28
+ "subsampling_factor": 4,
29
+ "ff_expansion_factor": 4,
30
+ "self_attention_model": "rotary",
31
+ "pos_emb_max_len": 5000,
32
+ "n_heads": 16,
33
+ "conv_kernel_size": 5,
34
+ "flash_attn": false,
35
+ "subs_kernel_size": 5,
36
+ "subsampling": "conv1d",
37
+ "conv_norm_type": "layer_norm"
38
+ },
39
+ "head": {
40
+ "_target_": "modeling_gigaam.RNNTHead",
41
+ "decoder": {
42
+ "pred_hidden": 320,
43
+ "pred_rnn_layers": 1,
44
+ "num_classes": 1025
45
+ },
46
+ "joint": {
47
+ "enc_hidden": 768,
48
+ "pred_hidden": 320,
49
+ "joint_hidden": 320,
50
+ "num_classes": 1025
51
+ }
52
+ },
53
+ "decoding": {
54
+ "_target_": "modeling_gigaam.RNNTGreedyDecoding",
55
+ "vocabulary": null,
56
+ "model_path": "tokenizer.model"
57
+ },
58
+ "model_name": "v3_e2e_rnnt",
59
+ "hashes": {
60
+ "model": "72e2a9b5c7caad963b2bbfd2f298c252",
61
+ "tokenizer": "3b3bf8370e882885d79731592fc99f98"
62
+ }
63
+ },
64
+ "_target_": "modeling_gigaam.GigaAMASR"
65
+ }
66
+ }
67
+ }
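
The auto_map entries above route AutoConfig/AutoModel to the classes defined in modeling_gigaam.py, so the upload can be loaded through the standard transformers Auto API with remote code enabled. A minimal loading sketch; the repository id below is a placeholder, not part of this commit:

from transformers import AutoModel

# trust_remote_code=True is required so AutoModel resolves
# modeling_gigaam.GigaAMModel via the auto_map above.
model = AutoModel.from_pretrained("<org>/<repo>", trust_remote_code=True)

# The nested Hydra-style config is kept on the config object as a plain dict.
print(model.config.cfg["model"]["cfg"]["model_name"])  # expected: "v3_e2e_rnnt"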
modeling_gigaam.py ADDED
@@ -0,0 +1,1426 @@
1
+ import json
2
+ import logging
3
+ import math
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from abc import ABC, abstractmethod
8
+ from pathlib import Path
9
+ from subprocess import CalledProcessError, run
10
+ from typing import Any, Dict, List, Optional, Tuple, Union
11
+
12
+ import hydra
13
+ import numpy as np
14
+ import omegaconf
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+ from hydra.utils import instantiate
19
+ from sentencepiece import SentencePieceProcessor
20
+ from torch import Tensor, nn
21
+ from torch.jit import TracerWarning
22
+ from transformers import PretrainedConfig, PreTrainedModel
23
+ from transformers.utils import cached_file
24
+
25
+ DIR_NAME = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append(DIR_NAME) # enable using modules through modeling_gigaam.<module_name>
27
+
28
+
29
+ IMPORT_FLASH = False
30
+ SAMPLE_RATE = 16000
31
+ LONGFORM_THRESHOLD = 25 * SAMPLE_RATE
32
+ _PIPELINE = None
33
+
34
+
35
+ ### preprocess ###
36
+
37
+
38
+ def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor:
39
+ """
40
+ Load an audio file and resample it to the specified sample rate.
41
+ """
42
+ cmd = [
43
+ "ffmpeg",
44
+ "-nostdin",
45
+ "-threads",
46
+ "0",
47
+ "-i",
48
+ audio_path,
49
+ "-f",
50
+ "s16le",
51
+ "-ac",
52
+ "1",
53
+ "-acodec",
54
+ "pcm_s16le",
55
+ "-ar",
56
+ str(sample_rate),
57
+ "-",
58
+ ]
59
+ try:
60
+ audio = run(cmd, capture_output=True, check=True).stdout
61
+ except CalledProcessError as exc:
62
+ raise RuntimeError("Failed to load audio") from exc
63
+
64
+ with warnings.catch_warnings():
65
+ warnings.simplefilter("ignore", category=UserWarning)
66
+ return torch.frombuffer(audio, dtype=torch.int16).float() / 32768.0
67
+
68
+
69
+ class SpecScaler(nn.Module):
70
+ """
71
+ Module that applies logarithmic scaling to spectrogram values.
72
+ This module clamps the input values within a certain range and then applies a natural logarithm.
73
+ """
74
+
75
+ def forward(self, x: Tensor) -> Tensor:
76
+ return torch.log(x.clamp_(1e-9, 1e9))
77
+
78
+
79
+ class FeatureExtractor(nn.Module):
80
+ """
81
+ Module for extracting Log-mel spectrogram features from raw audio signals.
82
+ This module uses Torchaudio's MelSpectrogram transform to extract features
83
+ and applies logarithmic scaling.
84
+ """
85
+
86
+ def __init__(self, sample_rate: int, features: int, **kwargs):
87
+ super().__init__()
88
+ self.hop_length = kwargs.get("hop_length", sample_rate // 100)
89
+ self.win_length = kwargs.get("win_length", sample_rate // 40)
90
+ self.n_fft = kwargs.get("n_fft", sample_rate // 40)
91
+ self.center = kwargs.get("center", True)
92
+ self.featurizer = nn.Sequential(
93
+ torchaudio.transforms.MelSpectrogram(
94
+ sample_rate=sample_rate,
95
+ n_mels=features,
96
+ win_length=self.win_length,
97
+ hop_length=self.hop_length,
98
+ n_fft=self.n_fft,
99
+ center=self.center,
100
+ ),
101
+ SpecScaler(),
102
+ )
103
+
104
+ def out_len(self, input_lengths: Tensor) -> Tensor:
105
+ """
106
+ Calculates the output length after the feature extraction process.
107
+ """
108
+ if self.center:
109
+ return (
110
+ input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
111
+ )
112
+ else:
113
+ return (
114
+ (input_lengths - self.win_length)
115
+ .div(self.hop_length, rounding_mode="floor")
116
+ .add(1)
117
+ .long()
118
+ )
119
+
120
+ def forward(self, input_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
121
+ """
122
+ Extract Log-mel spectrogram features from the input audio signal.
123
+ """
124
+ return self.featurizer(input_signal), self.out_len(length)
125
+
126
+
127
+ ### utils ###
128
+
129
+
130
+ def onnx_converter(
131
+ model_name: str,
132
+ module: torch.nn.Module,
133
+ out_dir: str,
134
+ inputs: Optional[Tuple[Tensor, ...]] = None,
135
+ input_names: Optional[List[str]] = None,
136
+ output_names: Optional[List[str]] = None,
137
+ dynamic_axes: Optional[
138
+ Union[Dict[str, List[int]], Dict[str, Dict[int, str]]]
139
+ ] = None,
140
+ opset_version: int = 17,
141
+ ):
142
+ if inputs is None:
143
+ inputs = module.input_example() # type: ignore[operator]
144
+ if input_names is None:
145
+ input_names = module.input_names() # type: ignore[operator]
146
+ if output_names is None:
147
+ output_names = module.output_names() # type: ignore[operator]
148
+
149
+ Path(out_dir).mkdir(exist_ok=True, parents=True)
150
+ out_path = str(Path(out_dir) / f"{model_name}.onnx")
151
+ saved_dtype = next(module.parameters()).dtype
152
+ with warnings.catch_warnings():
153
+ warnings.simplefilter("ignore", category=UserWarning)
154
+ warnings.simplefilter("ignore", category=TracerWarning)
155
+ torch.onnx.export(
156
+ module.to(torch.float32),
157
+ inputs,
158
+ out_path,
159
+ input_names=input_names,
160
+ output_names=output_names,
161
+ dynamic_axes=dynamic_axes,
162
+ opset_version=opset_version,
163
+ )
164
+ print(f"Succesfully ported onnx {model_name} to {out_path}.")
165
+ module.to(saved_dtype)
166
+
167
+
168
+ def format_time(seconds: float) -> str:
169
+ """
170
+ Formats time in seconds to HH:MM:SS:mm format.
171
+ """
172
+ hours = int(seconds // 3600)
173
+ minutes = int((seconds % 3600) // 60)
174
+ seconds = seconds % 60
175
+ full_seconds = int(seconds)
176
+ milliseconds = int((seconds - full_seconds) * 100)
177
+
178
+ if hours > 0:
179
+ return f"{hours:02}:{minutes:02}:{full_seconds:02}:{milliseconds:02}"
180
+ return f"{minutes:02}:{full_seconds:02}:{milliseconds:02}"
181
+
182
+
183
+ def rtt_half(x: Tensor) -> Tensor:
184
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
185
+ return torch.cat([-x2, x1], dim=x1.ndim - 1)
186
+
187
+
188
+ def apply_rotary_pos_emb(
189
+ q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0
190
+ ) -> Tuple[Tensor, Tensor]:
191
+ """
192
+ Applies Rotary Position Embeddings to query and key tensors.
193
+ """
194
+ cos, sin = (
195
+ cos[offset : q.shape[0] + offset, ...],
196
+ sin[offset : q.shape[0] + offset, ...],
197
+ )
198
+ return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
199
+
200
+
201
+ def _normalize_device(device: Optional[Union[str, torch.device]]) -> torch.device:
202
+ """Normalize device parameter to torch.device."""
203
+ if device is None:
204
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
205
+ return torch.device(device_str)
206
+ if isinstance(device, str):
207
+ return torch.device(device)
208
+ return device
209
+
210
+
211
+ def download_short_audio():
212
+ """Download test audio file if not exists"""
213
+ audio_file = "example.wav"
214
+ if not os.path.exists(audio_file):
215
+ os.system(
216
+ 'wget -O example.wav "https://cdn.chatwm.opensmodel.sberdevices.ru/GigaAM/example.wav"'
217
+ )
218
+ assert os.path.exists(audio_file), "Short audio file not found"
219
+ return audio_file
220
+
221
+
222
+ def download_long_audio():
223
+ """Download test audio file if not exists"""
224
+ audio_file = "long_example.wav"
225
+ if not os.path.exists(audio_file):
226
+ os.system(
227
+ 'wget -O long_example.wav "https://cdn.chatwm.opensmodel.sberdevices.ru/GigaAM/long_example.wav"'
228
+ )
229
+ assert os.path.exists(audio_file), "Long audio file not found"
230
+ return audio_file
231
+
232
+
233
+ class AudioDataset(torch.utils.data.Dataset):
234
+ """
235
+ Helper class for creating batched inputs
236
+ """
237
+
238
+ def __init__(self, lst: List[Union[str, np.ndarray, torch.Tensor]]):
239
+ assert isinstance(
240
+ lst[0], (str, np.ndarray, torch.Tensor)
241
+ ), f"Unexpected dtype: {type(lst[0])}"
242
+ self.lst = lst
243
+
244
+ def __len__(self):
245
+ return len(self.lst)
246
+
247
+ def __getitem__(self, idx):
248
+ item = self.lst[idx]
249
+ if isinstance(item, str):
250
+ wav_tns = load_audio(item)
251
+ elif isinstance(item, np.ndarray):
252
+ wav_tns = torch.from_numpy(item)
253
+ elif isinstance(item, torch.Tensor):
254
+ wav_tns = item
255
+ else:
256
+ raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
257
+ return wav_tns
258
+
259
+ @staticmethod
260
+ def collate(wavs):
261
+ lengths = torch.tensor([len(wav) for wav in wavs])
262
+ max_len = lengths.max().item()
263
+ wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
264
+ for idx, wav in enumerate(wavs):
265
+ wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
266
+ return wav_tns, lengths
267
+
268
+
269
+
270
+ ### vad utils ###
271
+
272
+
273
+ def get_pipeline(device: torch.device):
274
+ """
275
+ Retrieves a PyAnnote voice activity detection pipeline and moves it to the specified device.
276
+ The pipeline is loaded only once and reused across subsequent calls.
277
+ It requires the Hugging Face API token to be set in the HF_TOKEN environment variable.
278
+ """
279
+ global _PIPELINE
280
+ if _PIPELINE is not None:
281
+ return _PIPELINE.to(device)
282
+
283
+ from pyannote.audio import Model
284
+ from pyannote.audio.pipelines import VoiceActivityDetection
285
+
286
+ try:
287
+ hf_token = os.environ["HF_TOKEN"]
288
+ except KeyError as exc:
289
+ raise ValueError("HF_TOKEN environment variable is not set") from exc
290
+
291
+ model = Model.from_pretrained("pyannote/segmentation-3.0", use_auth_token=hf_token)
292
+ _PIPELINE = VoiceActivityDetection(segmentation=model)
293
+ _PIPELINE.instantiate({"min_duration_on": 0.0, "min_duration_off": 0.0})
294
+
295
+ return _PIPELINE.to(device)
296
+
297
+
298
+ def segment_audio_file(
299
+ wav_file: str,
300
+ sr: int,
301
+ max_duration: float = 22.0,
302
+ min_duration: float = 15.0,
303
+ strict_limit_duration: float = 30.0,
304
+ new_chunk_threshold: float = 0.2,
305
+ device: torch.device = torch.device("cpu"),
306
+ ) -> Tuple[List[torch.Tensor], List[Tuple[float, float]]]:
307
+ """
308
+ Segments an audio waveform into smaller chunks based on speech activity.
309
+ The segmentation is performed using a PyAnnote voice activity detection pipeline.
310
+ """
311
+
312
+ audio = load_audio(wav_file)
313
+ pipeline = get_pipeline(device)
314
+ sad_segments = pipeline(wav_file)
315
+
316
+ segments: List[torch.Tensor] = []
317
+ curr_duration = 0.0
318
+ curr_start = 0.0
319
+ curr_end = 0.0
320
+ boundaries: List[Tuple[float, float]] = []
321
+
322
+ def _update_segments(curr_start: float, curr_end: float, curr_duration: float):
323
+ if curr_duration > strict_limit_duration:
324
+ max_segments = int(curr_duration / strict_limit_duration) + 1
325
+ segment_duration = curr_duration / max_segments
326
+ curr_end = curr_start + segment_duration
327
+ for _ in range(max_segments - 1):
328
+ segments.append(audio[int(curr_start * sr) : int(curr_end * sr)])
329
+ boundaries.append((curr_start, curr_end))
330
+ curr_start = curr_end
331
+ curr_end += segment_duration
332
+ segments.append(audio[int(curr_start * sr) : int(curr_end * sr)])
333
+ boundaries.append((curr_start, curr_end))
334
+
335
+ # Concatenate segments from the pipeline into ASR chunks according to max/min duration
336
+ # Segments longer than strict_limit_duration are split manually
337
+ for segment in sad_segments.get_timeline().support():
338
+ start = max(0, segment.start)
339
+ end = min(audio.shape[0] / sr, segment.end)
340
+ if curr_duration > new_chunk_threshold and (
341
+ curr_duration + (end - curr_end) > max_duration
342
+ or curr_duration > min_duration
343
+ ):
344
+ _update_segments(curr_start, curr_end, curr_duration)
345
+ curr_start = start
346
+ curr_end = end
347
+ curr_duration = curr_end - curr_start
348
+
349
+ if curr_duration > new_chunk_threshold:
350
+ _update_segments(curr_start, curr_end, curr_duration)
351
+
352
+ return segments, boundaries
353
+
354
+
355
+ ### encoder ###
356
+
357
+
358
+
359
+ class StridingSubsampling(nn.Module):
360
+ """
361
+ Strided Subsampling layer used to reduce the sequence length.
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ subsampling: str,
367
+ kernel_size: int,
368
+ subsampling_factor: int,
369
+ feat_in: int,
370
+ feat_out: int,
371
+ conv_channels: int,
372
+ ):
373
+ super().__init__()
374
+ self.subsampling_type = subsampling
375
+ assert self.subsampling_type in ["conv1d", "conv2d"]
376
+ self._sampling_num = int(math.log(subsampling_factor, 2))
377
+ self._stride = 2
378
+ self._kernel_size = kernel_size
379
+ self._padding = (self._kernel_size - 1) // 2
380
+
381
+ layers: List[nn.Module] = []
382
+ in_channels = 1 if self.subsampling_type == "conv2d" else feat_in
383
+ subs_conv_class = (
384
+ torch.nn.Conv2d if self.subsampling_type == "conv2d" else torch.nn.Conv1d
385
+ )
386
+ for _ in range(self._sampling_num):
387
+ layers.append(
388
+ subs_conv_class(
389
+ in_channels=in_channels,
390
+ out_channels=conv_channels,
391
+ kernel_size=self._kernel_size,
392
+ stride=self._stride,
393
+ padding=self._padding,
394
+ )
395
+ )
396
+ layers.append(nn.ReLU())
397
+ in_channels = conv_channels
398
+
399
+ out_length = self.calc_output_length(torch.tensor(feat_in))
400
+ if self.subsampling_type == "conv2d":
401
+ self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
402
+ self.conv = torch.nn.Sequential(*layers)
403
+
404
+ def calc_output_length(self, lengths: Tensor) -> Tensor:
405
+ """
406
+ Calculates the output length after applying the subsampling.
407
+ """
408
+ lengths = lengths.to(torch.float)
409
+ add_pad = 2 * self._padding - self._kernel_size
410
+ for _ in range(self._sampling_num):
411
+ lengths = torch.div(lengths + add_pad, self._stride) + 1.0
412
+ lengths = torch.floor(lengths)
413
+ return lengths.to(dtype=torch.int)
414
+
415
+ def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]:
416
+ if self.subsampling_type == "conv2d":
417
+ x = self.conv(x.unsqueeze(1))
418
+ b, _, t, _ = x.size()
419
+ x = self.out(x.transpose(1, 2).reshape(b, t, -1))
420
+ else:
421
+ x = self.conv(x.transpose(1, 2)).transpose(1, 2)
422
+ return x, self.calc_output_length(lengths)
423
+
424
+
425
+ class MultiHeadAttention(nn.Module, ABC):
426
+ """
427
+ Base class of Multi-Head Attention Mechanisms.
428
+ """
429
+
430
+ def __init__(
431
+ self, n_head: int, n_feat: int, flash_attn=False, torch_sdpa_attn=False
432
+ ):
433
+ super().__init__()
434
+ assert n_feat % n_head == 0
435
+ self.d_k = n_feat // n_head
436
+ self.h = n_head
437
+ self.linear_q = nn.Linear(n_feat, n_feat)
438
+ self.linear_k = nn.Linear(n_feat, n_feat)
439
+ self.linear_v = nn.Linear(n_feat, n_feat)
440
+ self.linear_out = nn.Linear(n_feat, n_feat)
441
+ self.flash_attn = flash_attn
442
+ self.torch_sdpa_attn = torch_sdpa_attn
443
+ if self.flash_attn and not IMPORT_FLASH:
444
+ raise RuntimeError(
445
+ f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. "
446
+ "Please install flash_attn or use --no_flash flag. "
447
+ "If you have already done this, "
448
+ "--force-reinstall flag might be useful"
449
+ )
450
+
451
+ def forward_qkv(
452
+ self, query: Tensor, key: Tensor, value: Tensor
453
+ ) -> Tuple[Tensor, Tensor, Tensor]:
454
+ """
455
+ Projects the inputs into queries, keys, and values for multi-head attention.
456
+ """
457
+ b = query.size(0)
458
+ q = self.linear_q(query).view(b, -1, self.h, self.d_k)
459
+ k = self.linear_k(key).view(b, -1, self.h, self.d_k)
460
+ v = self.linear_v(value).view(b, -1, self.h, self.d_k)
461
+ if self.flash_attn:
462
+ return q, k, v
463
+ return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
464
+
465
+ def forward_attention(
466
+ self, value: Tensor, scores: Tensor, mask: Optional[Tensor]
467
+ ) -> Tensor:
468
+ """
469
+ Computes the scaled dot-product attention given the projected values and scores.
470
+ """
471
+ b = value.size(0)
472
+ if mask is not None:
473
+ mask = mask.unsqueeze(1)
474
+ scores = scores.masked_fill(mask, -10000.0)
475
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
476
+ else:
477
+ attn = torch.softmax(scores, dim=-1)
478
+ x = torch.matmul(attn, value)
479
+ x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k)
480
+ return self.linear_out(x)
481
+
482
+
483
+ class RelPositionMultiHeadAttention(MultiHeadAttention):
484
+ """
485
+ Relative Position Multi-Head Attention module.
486
+ """
487
+
488
+ def __init__(self, n_head: int, n_feat: int):
489
+ super().__init__(n_head, n_feat)
490
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
491
+ self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
492
+ self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
493
+
494
+ def rel_shift(self, x: Tensor) -> Tensor:
495
+ b, h, qlen, pos_len = x.size()
496
+ x = torch.nn.functional.pad(x, pad=(1, 0))
497
+ x = x.view(b, h, -1, qlen)
498
+ return x[:, :, 1:].view(b, h, qlen, pos_len)
499
+
500
+ def forward(
501
+ self,
502
+ query: Tensor,
503
+ key: Tensor,
504
+ value: Tensor,
505
+ pos_emb: Tensor,
506
+ mask: Optional[Tensor] = None,
507
+ ) -> Tensor:
508
+ q, k, v = self.forward_qkv(query, key, value)
509
+ q = q.transpose(1, 2)
510
+ p = self.linear_pos(pos_emb)
511
+ p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2)
512
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
513
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
514
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
515
+ matrix_bd = self.rel_shift(matrix_bd)
516
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
517
+ matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
518
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
519
+ return self.forward_attention(v, scores, mask)
520
+
521
+
522
+ class RotaryPositionMultiHeadAttention(MultiHeadAttention):
523
+ """
524
+ Rotary Position Multi-Head Attention module.
525
+ """
526
+
527
+ def forward(
528
+ self,
529
+ query: Tensor,
530
+ key: Tensor,
531
+ value: Tensor,
532
+ pos_emb: List[Tensor],
533
+ mask: Optional[Tensor] = None,
534
+ ) -> Tensor:
535
+ b, t, _ = value.size()
536
+ query = query.transpose(0, 1).view(t, b, self.h, self.d_k)
537
+ key = key.transpose(0, 1).view(t, b, self.h, self.d_k)
538
+ value = value.transpose(0, 1).view(t, b, self.h, self.d_k)
539
+
540
+ cos, sin = pos_emb
541
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
542
+
543
+ q, k, v = self.forward_qkv(
544
+ query.view(t, b, self.h * self.d_k).transpose(0, 1),
545
+ key.view(t, b, self.h * self.d_k).transpose(0, 1),
546
+ value.view(t, b, self.h * self.d_k).transpose(0, 1),
547
+ )
548
+
549
+ if not self.flash_attn and not self.torch_sdpa_attn:
550
+ scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
551
+ return self.forward_attention(v, scores, mask)
552
+ elif self.flash_attn:
553
+ if mask is None:
554
+ scores = flash_attn_func(q, k, v)
555
+ else:
556
+ scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
557
+ scores = scores.view(b, -1, self.h * self.d_k)
558
+ return self.linear_out(scores)
559
+ else:
560
+ attn_mask = None if mask is None else ~mask.unsqueeze(1)
561
+ attn_output = F.scaled_dot_product_attention(
562
+ q,
563
+ k,
564
+ v,
565
+ attn_mask=attn_mask,
566
+ )
567
+ attn_output = attn_output.transpose(1, 2).reshape(b, t, self.h * self.d_k)
568
+ return self.linear_out(attn_output)
569
+
570
+
571
+ class PositionalEncoding(nn.Module, ABC):
572
+ """
573
+ Base class of Positional Encodings.
574
+ """
575
+
576
+ def __init__(self, dim: int, base: int):
577
+ super().__init__()
578
+ self.dim = dim
579
+ self.base = base
580
+
581
+ @abstractmethod
582
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
583
+ pass
584
+
585
+ def extend_pe(self, length: int, device: torch.device):
586
+ """
587
+ Extends the positional encoding buffer to process longer sequences.
588
+ """
589
+ pe = self.create_pe(length, device)
590
+ if pe is None:
591
+ return
592
+ if hasattr(self, "pe"):
593
+ self.pe = pe
594
+ else:
595
+ self.register_buffer("pe", pe, persistent=False)
596
+
597
+
598
+ class RelPositionalEmbedding(PositionalEncoding):
599
+ """
600
+ Relative Positional Embedding module.
601
+ """
602
+
603
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
604
+ """
605
+ Creates the relative positional encoding matrix.
606
+ """
607
+ if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1:
608
+ return None
609
+ positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
610
+ pos_length = positions.size(0)
611
+ pe = torch.zeros(pos_length, self.dim, device=positions.device)
612
+ div_term = torch.exp(
613
+ torch.arange(0, self.dim, 2, device=pe.device)
614
+ * -(math.log(10000.0) / self.dim)
615
+ )
616
+ pe[:, 0::2] = torch.sin(positions * div_term)
617
+ pe[:, 1::2] = torch.cos(positions * div_term)
618
+ return pe.unsqueeze(0)
619
+
620
+ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
621
+ input_len = x.size(1)
622
+ center_pos = self.pe.size(1) // 2 + 1
623
+ start_pos = center_pos - input_len
624
+ end_pos = center_pos + input_len - 1
625
+ return x, self.pe[:, start_pos:end_pos]
626
+
627
+
628
+ class RotaryPositionalEmbedding(PositionalEncoding):
629
+ """
630
+ Rotary Positional Embedding module.
631
+ """
632
+
633
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
634
+ """
635
+ Creates or extends the rotary positional encoding matrix.
636
+ """
637
+ if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
638
+ return None
639
+ positions = torch.arange(0, length, dtype=torch.float32, device=device)
640
+ inv_freq = 1.0 / (
641
+ self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
642
+ )
643
+ t = torch.arange(length, device=positions.device).type_as(inv_freq)
644
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
645
+ emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
646
+ return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
647
+
648
+ def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
649
+ cos_emb = self.pe[0 : x.shape[1]]
650
+ half_pe = self.pe.shape[0] // 2
651
+ sin_emb = self.pe[half_pe : half_pe + x.shape[1]]
652
+ return x, [cos_emb, sin_emb]
653
+
654
+
655
+ class ConformerConvolution(nn.Module):
656
+ """
657
+ Conformer Convolution module.
658
+ """
659
+
660
+ def __init__(
661
+ self,
662
+ d_model: int,
663
+ kernel_size: int,
664
+ norm_type: str,
665
+ ):
666
+ super().__init__()
667
+ assert (kernel_size - 1) % 2 == 0
668
+ assert norm_type in ["batch_norm", "layer_norm"]
669
+ self.norm_type = norm_type
670
+ self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
671
+ self.depthwise_conv = nn.Conv1d(
672
+ in_channels=d_model,
673
+ out_channels=d_model,
674
+ kernel_size=kernel_size,
675
+ padding=(kernel_size - 1) // 2,
676
+ groups=d_model,
677
+ bias=True,
678
+ )
679
+ self.batch_norm = (
680
+ nn.BatchNorm1d(d_model)
681
+ if norm_type == "batch_norm"
682
+ else nn.LayerNorm(d_model)
683
+ )
684
+ self.activation = nn.SiLU()
685
+ self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
686
+
687
+ def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor:
688
+ x = x.transpose(1, 2)
689
+ x = self.pointwise_conv1(x)
690
+ x = nn.functional.glu(x, dim=1)
691
+ if pad_mask is not None:
692
+ x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)
693
+ x = self.depthwise_conv(x)
694
+ if self.norm_type == "batch_norm":
695
+ x = self.batch_norm(x)
696
+ else:
697
+ x = self.batch_norm(x.transpose(1, 2)).transpose(1, 2)
698
+ x = self.activation(x)
699
+ x = self.pointwise_conv2(x)
700
+ return x.transpose(1, 2)
701
+
702
+
703
+ class ConformerFeedForward(nn.Module):
704
+ """
705
+ Conformer Feed Forward module.
706
+ """
707
+
708
+ def __init__(self, d_model: int, d_ff: int, use_bias=True):
709
+ super().__init__()
710
+ self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
711
+ self.activation = nn.SiLU()
712
+ self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)
713
+
714
+ def forward(self, x: Tensor) -> Tensor:
715
+ return self.linear2(self.activation(self.linear1(x)))
716
+
717
+
718
+ class ConformerLayer(nn.Module):
719
+ """
720
+ Conformer Layer module.
721
+ This module combines several submodules including feed forward networks,
722
+ depthwise separable convolution, and multi-head self-attention
723
+ to form a single Conformer block.
724
+ """
725
+
726
+ def __init__(
727
+ self,
728
+ d_model: int,
729
+ d_ff: int,
730
+ self_attention_model: str,
731
+ n_heads: int = 16,
732
+ conv_norm_type: str = "batch_norm",
733
+ conv_kernel_size: int = 31,
734
+ flash_attn: bool = False,
735
+ ):
736
+ super().__init__()
737
+ self.fc_factor = 0.5
738
+ self.norm_feed_forward1 = nn.LayerNorm(d_model)
739
+ self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
740
+ self.norm_conv = nn.LayerNorm(d_model)
741
+ self.conv = ConformerConvolution(
742
+ d_model=d_model,
743
+ kernel_size=conv_kernel_size,
744
+ norm_type=conv_norm_type,
745
+ )
746
+ self.norm_self_att = nn.LayerNorm(d_model)
747
+ if self_attention_model == "rotary":
748
+ self.self_attn: nn.Module = RotaryPositionMultiHeadAttention(
749
+ n_head=n_heads,
750
+ n_feat=d_model,
751
+ flash_attn=flash_attn,
752
+ torch_sdpa_attn=not flash_attn,
753
+ )
754
+ else:
755
+ assert not flash_attn, "Not supported flash_attn for rel_pos"
756
+ self.self_attn = RelPositionMultiHeadAttention(
757
+ n_head=n_heads,
758
+ n_feat=d_model,
759
+ )
760
+ self.norm_feed_forward2 = nn.LayerNorm(d_model)
761
+ self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
762
+ self.norm_out = nn.LayerNorm(d_model)
763
+
764
+ def forward(
765
+ self,
766
+ x: Tensor,
767
+ pos_emb: Union[Tensor, List[Tensor]],
768
+ att_mask: Optional[Tensor] = None,
769
+ pad_mask: Optional[Tensor] = None,
770
+ ) -> Tensor:
771
+ residual = x
772
+ x = self.norm_feed_forward1(x)
773
+ x = self.feed_forward1(x)
774
+ residual = residual + x * self.fc_factor
775
+
776
+ x = self.norm_self_att(residual)
777
+ x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
778
+ residual = residual + x
779
+
780
+ x = self.norm_conv(residual)
781
+ x = self.conv(x, pad_mask=pad_mask)
782
+ residual = residual + x
783
+
784
+ x = self.norm_feed_forward2(residual)
785
+ x = self.feed_forward2(x)
786
+ residual = residual + x * self.fc_factor
787
+
788
+ x = self.norm_out(residual)
789
+ return x
790
+
791
+
792
+ class ConformerEncoder(nn.Module):
793
+ """
794
+ Conformer Encoder module.
795
+ This module encapsulates the entire Conformer encoder architecture,
796
+ consisting of a StridingSubsampling layer, positional embeddings, and
797
+ a stack of Conformer Layers.
798
+ It serves as the main component responsible for processing speech features.
799
+ """
800
+
801
+ def __init__(
802
+ self,
803
+ feat_in: int = 64,
804
+ n_layers: int = 16,
805
+ d_model: int = 768,
806
+ subsampling: str = "conv2d",
807
+ subs_kernel_size: int = 3,
808
+ subsampling_factor: int = 4,
809
+ ff_expansion_factor: int = 4,
810
+ self_attention_model: str = "rotary",
811
+ n_heads: int = 16,
812
+ pos_emb_max_len: int = 5000,
813
+ conv_norm_type: str = "batch_norm",
814
+ conv_kernel_size: int = 31,
815
+ flash_attn: bool = False,
816
+ ):
817
+ super().__init__()
818
+ self.feat_in = feat_in
819
+ assert self_attention_model in [
820
+ "rotary",
821
+ "rel_pos",
822
+ ], f"Not supported attn = {self_attention_model}"
823
+
824
+ self.pre_encode = StridingSubsampling(
825
+ subsampling=subsampling,
826
+ kernel_size=subs_kernel_size,
827
+ subsampling_factor=subsampling_factor,
828
+ feat_in=feat_in,
829
+ feat_out=d_model,
830
+ conv_channels=d_model,
831
+ )
832
+
833
+ self.pos_emb_max_len = pos_emb_max_len
834
+ if self_attention_model == "rotary":
835
+ self.pos_enc: PositionalEncoding = RotaryPositionalEmbedding(
836
+ d_model // n_heads, pos_emb_max_len
837
+ )
838
+ else:
839
+ self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len)
840
+
841
+ self.layers = nn.ModuleList()
842
+ for _ in range(n_layers):
843
+ layer = ConformerLayer(
844
+ d_model=d_model,
845
+ d_ff=d_model * ff_expansion_factor,
846
+ self_attention_model=self_attention_model,
847
+ n_heads=n_heads,
848
+ conv_norm_type=conv_norm_type,
849
+ conv_kernel_size=conv_kernel_size,
850
+ flash_attn=flash_attn,
851
+ )
852
+ self.layers.append(layer)
853
+
854
+ def input_example(
855
+ self,
856
+ batch_size: int = 1,
857
+ seqlen: int = 200,
858
+ ) -> Tuple[Tensor, Tensor]:
859
+ device = next(self.parameters()).device
860
+ features = torch.zeros(batch_size, self.feat_in, seqlen)
861
+ feature_lengths = torch.full([batch_size], features.shape[-1])
862
+ return features.float().to(device), feature_lengths.to(device)
863
+
864
+ def input_names(self) -> List[str]:
865
+ return ["audio_signal", "length"]
866
+
867
+ def output_names(self) -> List[str]:
868
+ return ["encoded", "encoded_len"]
869
+
870
+ def dynamic_axes(self) -> Dict[str, Dict[int, str]]:
871
+ return {
872
+ "audio_signal": {0: "batch_size", 2: "seq_len"},
873
+ "length": {0: "batch_size"},
874
+ "encoded": {0: "batch_size", 1: "seq_len"},
875
+ "encoded_len": {0: "batch_size"},
876
+ }
877
+
878
+ def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
879
+ if not hasattr(self.pos_enc, "pe"):
880
+ self.pos_enc.extend_pe(self.pos_emb_max_len, audio_signal.device)
881
+
882
+ audio_signal, length = self.pre_encode(
883
+ x=audio_signal.transpose(1, 2), lengths=length
884
+ )
885
+
886
+ max_len = audio_signal.size(1)
887
+ audio_signal, pos_emb = self.pos_enc(x=audio_signal)
888
+
889
+ pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand(
890
+ length.size(0), -1
891
+ ) < length.unsqueeze(-1)
892
+
893
+ att_mask = None
894
+ if audio_signal.shape[0] > 1:
895
+ att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1])
896
+ att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2))
897
+ att_mask = ~att_mask
898
+
899
+ pad_mask = ~pad_mask
900
+
901
+ for layer in self.layers:
902
+ audio_signal = layer(
903
+ x=audio_signal,
904
+ pos_emb=pos_emb,
905
+ att_mask=att_mask,
906
+ pad_mask=pad_mask,
907
+ )
908
+
909
+ return audio_signal.transpose(1, 2), length
910
+
911
+
912
+ ### decoders ###
913
+
914
+
915
+ class CTCHead(nn.Module):
916
+ """
917
+ CTC Head module for Connectionist Temporal Classification.
918
+ """
919
+
920
+ def __init__(self, feat_in: int, num_classes: int):
921
+ super().__init__()
922
+ self.decoder_layers = torch.nn.Sequential(
923
+ torch.nn.Conv1d(feat_in, num_classes, kernel_size=1)
924
+ )
925
+
926
+ def forward(self, encoder_output: Tensor) -> Tensor:
927
+ return torch.nn.functional.log_softmax(
928
+ self.decoder_layers(encoder_output).transpose(1, 2), dim=-1
929
+ )
930
+
931
+
932
+ class RNNTJoint(nn.Module):
933
+ """
934
+ RNN-Transducer Joint Network Module.
935
+ This module combines the outputs of the encoder and the prediction network using
936
+ a linear transformation followed by ReLU activation and another linear projection.
937
+ """
938
+
939
+ def __init__(
940
+ self, enc_hidden: int, pred_hidden: int, joint_hidden: int, num_classes: int
941
+ ):
942
+ super().__init__()
943
+ self.enc_hidden = enc_hidden
944
+ self.pred_hidden = pred_hidden
945
+ self.pred = nn.Linear(pred_hidden, joint_hidden)
946
+ self.enc = nn.Linear(enc_hidden, joint_hidden)
947
+ self.joint_net = nn.Sequential(nn.ReLU(), nn.Linear(joint_hidden, num_classes))
948
+
949
+ def joint(self, encoder_out: Tensor, decoder_out: Tensor) -> Tensor:
950
+ """
951
+ Combine the encoder and prediction network outputs into a joint representation.
952
+ """
953
+ enc = self.enc(encoder_out).unsqueeze(2)
954
+ pred = self.pred(decoder_out).unsqueeze(1)
955
+ return self.joint_net(enc + pred).log_softmax(-1)
956
+
957
+ def input_example(self) -> Tuple[Tensor, Tensor]:
958
+ device = next(self.parameters()).device
959
+ enc = torch.zeros(1, self.enc_hidden, 1)
960
+ dec = torch.zeros(1, self.pred_hidden, 1)
961
+ return enc.float().to(device), dec.float().to(device)
962
+
963
+ def input_names(self) -> List[str]:
964
+ return ["enc", "dec"]
965
+
966
+ def output_names(self) -> List[str]:
967
+ return ["joint"]
968
+
969
+ def forward(self, enc: Tensor, dec: Tensor) -> Tensor:
970
+ return self.joint(enc.transpose(1, 2), dec.transpose(1, 2))
971
+
972
+
973
+ class RNNTDecoder(nn.Module):
974
+ """
975
+ RNN-Transducer Decoder Module.
976
+ This module handles the prediction network part of the RNN-Transducer architecture.
977
+ """
978
+
979
+ def __init__(self, pred_hidden: int, pred_rnn_layers: int, num_classes: int):
980
+ super().__init__()
981
+ self.blank_id = num_classes - 1
982
+ self.pred_hidden = pred_hidden
983
+ self.embed = nn.Embedding(num_classes, pred_hidden, padding_idx=self.blank_id)
984
+ self.lstm = nn.LSTM(pred_hidden, pred_hidden, pred_rnn_layers)
985
+
986
+ def predict(
987
+ self,
988
+ x: Optional[Tensor],
989
+ state: Optional[Tensor],
990
+ batch_size: int = 1,
991
+ ) -> Tuple[Tensor, Tensor]:
992
+ """
993
+ Make predictions based on the current input and previous states.
994
+ If no input is provided, use zeros as the initial input.
995
+ """
996
+ if x is not None:
997
+ emb: Tensor = self.embed(x)
998
+ else:
999
+ emb = torch.zeros(
1000
+ (batch_size, 1, self.pred_hidden), device=next(self.parameters()).device
1001
+ )
1002
+ g, hid = self.lstm(emb.transpose(0, 1), state)
1003
+ return g.transpose(0, 1), hid
1004
+
1005
+ def input_example(self) -> Tuple[Tensor, Tensor, Tensor]:
1006
+ device = next(self.parameters()).device
1007
+ label = torch.tensor([[0]]).to(device)
1008
+ hidden_h = torch.zeros(1, 1, self.pred_hidden).to(device)
1009
+ hidden_c = torch.zeros(1, 1, self.pred_hidden).to(device)
1010
+ return label, hidden_h, hidden_c
1011
+
1012
+ def input_names(self) -> List[str]:
1013
+ return ["x", "h", "c"]
1014
+
1015
+ def output_names(self) -> List[str]:
1016
+ return ["dec", "h", "c"]
1017
+
1018
+ def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
1019
+ """
1020
+ ONNX-specific forward with x, state = (h, c) -> x, h, c.
1021
+ """
1022
+ emb = self.embed(x)
1023
+ g, (h, c) = self.lstm(emb.transpose(0, 1), (h, c))
1024
+ return g.transpose(0, 1), h, c
1025
+
1026
+
1027
+ class RNNTHead(nn.Module):
1028
+ """
1029
+ RNN-Transducer Head Module.
1030
+ This module combines the decoder and joint network components of the RNN-Transducer architecture.
1031
+ """
1032
+
1033
+ def __init__(self, decoder: Dict[str, int], joint: Dict[str, int]):
1034
+ super().__init__()
1035
+ self.decoder = RNNTDecoder(**decoder)
1036
+ self.joint = RNNTJoint(**joint)
1037
+
1038
+
1039
+ ### decoding ###
1040
+
1041
+
1042
+ class Tokenizer:
1043
+ """
1044
+ Tokenizer for converting between text and token IDs.
1045
+ The tokenizer can operate either character-wise or using a pre-trained SentencePiece model.
1046
+ """
1047
+
1048
+ def __init__(self, vocab: List[str], model_path: Optional[str] = None):
1049
+ self.charwise = model_path is None
1050
+ if self.charwise:
1051
+ self.vocab = vocab
1052
+ else:
1053
+ self.model = SentencePieceProcessor()
1054
+ self.model.load(model_path)
1055
+
1056
+ def decode(self, tokens: List[int]) -> str:
1057
+ """
1058
+ Convert a list of token IDs back to a string.
1059
+ """
1060
+ if self.charwise:
1061
+ return "".join(self.vocab[tok] for tok in tokens)
1062
+ return self.model.decode(tokens)
1063
+
1064
+ def __len__(self):
1065
+ """
1066
+ Get the total number of tokens in the vocabulary.
1067
+ """
1068
+ return len(self.vocab) if self.charwise else len(self.model)
1069
+
1070
+
1071
+ class CTCGreedyDecoding:
1072
+ """
1073
+ Class for performing greedy decoding of CTC outputs.
1074
+ """
1075
+
1076
+ def __init__(self, vocabulary: List[str], model_path: Optional[str] = None):
1077
+ self.tokenizer = Tokenizer(vocabulary, model_path)
1078
+ self.blank_id = len(self.tokenizer)
1079
+
1080
+ @torch.inference_mode()
1081
+ def decode(self, head: CTCHead, encoded: Tensor, lengths: Tensor) -> List[str]:
1082
+ """
1083
+ Decode the output of a CTC model into a list of hypotheses.
1084
+ """
1085
+ log_probs = head(encoder_output=encoded)
1086
+ assert (
1087
+ len(log_probs.shape) == 3
1088
+ ), f"Expected log_probs shape {log_probs.shape} == [B, T, C]"
1089
+ b, _, c = log_probs.shape
1090
+ assert (
1091
+ c == len(self.tokenizer) + 1
1092
+ ), f"Num classes {c} != len(vocab) + 1 {len(self.tokenizer) + 1}"
1093
+ labels = log_probs.argmax(dim=-1, keepdim=False)
1094
+
1095
+ skip_mask = labels != self.blank_id
1096
+ skip_mask[:, 1:] = torch.logical_and(
1097
+ skip_mask[:, 1:], labels[:, 1:] != labels[:, :-1]
1098
+ )
1099
+ for i, length in enumerate(lengths):
1100
+ skip_mask[i, length:] = 0
1101
+
1102
+ pred_texts: List[str] = []
1103
+ for i in range(b):
1104
+ pred_texts.append(
1105
+ "".join(self.tokenizer.decode(labels[i][skip_mask[i]].cpu().tolist()))
1106
+ )
1107
+ return pred_texts
1108
+
1109
+
1110
+ class RNNTGreedyDecoding:
1111
+ def __init__(
1112
+ self,
1113
+ vocabulary: List[str],
1114
+ model_path: Optional[str] = None,
1115
+ max_symbols_per_step: int = 10,
1116
+ ):
1117
+ """
1118
+ Class for performing greedy decoding of RNN-T outputs.
1119
+ """
1120
+ self.tokenizer = Tokenizer(vocabulary, model_path)
1121
+ self.blank_id = len(self.tokenizer)
1122
+ self.max_symbols = max_symbols_per_step
1123
+
1124
+ def _greedy_decode(self, head: RNNTHead, x: Tensor, seqlen: Tensor) -> str:
1125
+ """
1126
+ Internal helper function for performing greedy decoding on a single sequence.
1127
+ """
1128
+ hyp: List[int] = []
1129
+ dec_state: Optional[Tensor] = None
1130
+ last_label: Optional[Tensor] = None
1131
+ for t in range(seqlen):
1132
+ f = x[t, :, :].unsqueeze(1)
1133
+ not_blank = True
1134
+ new_symbols = 0
1135
+ while not_blank and new_symbols < self.max_symbols:
1136
+ g, hidden = head.decoder.predict(last_label, dec_state)
1137
+ k = head.joint.joint(f, g)[0, 0, 0, :].argmax(0).item()
1138
+ if k == self.blank_id:
1139
+ not_blank = False
1140
+ else:
1141
+ hyp.append(int(k))
1142
+ dec_state = hidden
1143
+ last_label = torch.tensor([[hyp[-1]]]).to(x.device)
1144
+ new_symbols += 1
1145
+
1146
+ return self.tokenizer.decode(hyp)
1147
+
1148
+ @torch.inference_mode()
1149
+ def decode(self, head: RNNTHead, encoded: Tensor, enc_len: Tensor) -> List[str]:
1150
+ """
1151
+ Decode the output of an RNN-T model into a list of hypotheses.
1152
+ """
1153
+ b = encoded.shape[0]
1154
+ pred_texts = []
1155
+ encoded = encoded.transpose(1, 2)
1156
+ for i in range(b):
1157
+ inseq = encoded[i, :, :].unsqueeze(1)
1158
+ pred_texts.append(self._greedy_decode(head, inseq, enc_len[i]))
1159
+ return pred_texts
1160
+
1161
+
1162
+ ### models ###
1163
+
1164
+
1165
+ class GigaAM(nn.Module):
1166
+ """
1167
+ Giga Acoustic Model: Self-Supervised Model for Speech Tasks
1168
+ """
1169
+
1170
+ def __init__(self, cfg: omegaconf.DictConfig):
1171
+ super().__init__()
1172
+ self.cfg = cfg
1173
+ self.preprocessor = hydra.utils.instantiate(self.cfg.preprocessor)
1174
+ self.encoder = hydra.utils.instantiate(self.cfg.encoder)
1175
+
1176
+ def forward(
1177
+ self, features: Tensor, feature_lengths: Tensor
1178
+ ) -> Tuple[Tensor, Tensor]:
1179
+ """
1180
+ Perform forward pass through the preprocessor and encoder.
1181
+ """
1182
+ features, feature_lengths = self.preprocessor(features, feature_lengths)
1183
+ if self._device.type == "cpu":
1184
+ return self.encoder(features, feature_lengths)
1185
+ with torch.autocast(device_type=self._device.type, dtype=torch.float16):
1186
+ return self.encoder(features, feature_lengths)
1187
+
1188
+ @property
1189
+ def _device(self) -> torch.device:
1190
+ return next(self.parameters()).device
1191
+
1192
+ @property
1193
+ def _dtype(self) -> torch.dtype:
1194
+ return next(self.parameters()).dtype
1195
+
1196
+ def prepare_wav(self, wav_file: str) -> Tuple[Tensor, Tensor]:
1197
+ """
1198
+ Prepare an audio file for processing by loading it onto
1199
+ the correct device and converting its format.
1200
+ """
1201
+ wav = load_audio(wav_file)
1202
+ wav = wav.to(self._device).to(self._dtype).unsqueeze(0)
1203
+ length = torch.full([1], wav.shape[-1], device=self._device)
1204
+ return wav, length
1205
+
1206
+ def embed_audio(self, wav_file: str) -> Tuple[Tensor, Tensor]:
1207
+ """
1208
+ Extract audio representations using the GigaAM model.
1209
+ """
1210
+ wav, length = self.prepare_wav(wav_file)
1211
+ encoded, encoded_len = self.forward(wav, length)
1212
+ return encoded, encoded_len
1213
+
1214
+ def to_onnx(self, dir_path: str = ".") -> None:
1215
+ """
1216
+ Export the model to ONNX and save its config in the specified directory.
1217
+ """
1218
+ self._to_onnx(dir_path)
1219
+ omegaconf.OmegaConf.save(self.cfg, f"{dir_path}/{self.cfg.model_name}.yaml")
1220
+
1221
+ def _to_onnx(self, dir_path: str = ".") -> None:
1222
+ """
1223
+ Export the encoder to ONNX in the specified directory.
1224
+ """
1225
+ onnx_converter(
1226
+ model_name=f"{self.cfg.model_name}_encoder",
1227
+ out_dir=dir_path,
1228
+ module=self.encoder,
1229
+ dynamic_axes=self.encoder.dynamic_axes(),
1230
+ )
1231
+
1232
+
1233
+ class GigaAMASR(GigaAM):
1234
+ """
1235
+ Giga Acoustic Model for Speech Recognition
1236
+ """
1237
+
1238
+ def __init__(self, cfg: omegaconf.DictConfig):
1239
+ super().__init__(cfg)
1240
+ self.head = hydra.utils.instantiate(self.cfg.head)
1241
+ self.decoding = hydra.utils.instantiate(self.cfg.decoding)
1242
+
1243
+ @torch.inference_mode()
1244
+ def transcribe(self, wav_file: str) -> str:
1245
+ """
1246
+ Transcribes a short audio file into text.
1247
+ """
1248
+ wav, length = self.prepare_wav(wav_file)
1249
+ if length.item() > LONGFORM_THRESHOLD:
1250
+ raise ValueError("Too long wav file, use 'transcribe_longform' method.")
1251
+
1252
+ encoded, encoded_len = self.forward(wav, length)
1253
+ return self.decoding.decode(self.head, encoded, encoded_len)[0]
1254
+
1255
+ def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor:
1256
+ """
1257
+ Encoder-decoder forward pass for exporting the whole model as a single ONNX graph.
1258
+ """
1259
+ return self.head(self.encoder(features, feature_lengths)[0])
1260
+
1261
+ def _to_onnx(self, dir_path: str = ".") -> None:
1262
+ """
1263
+ Export onnx ASR model.
1264
+ `ctc`: exported entirely in encoder-decoder format.
1265
+ `rnnt`: exported in encoder/decoder/joint parts separately.
1266
+ """
1267
+ if "ctc" in self.cfg.model_name:
1268
+ saved_forward = self.forward
1269
+ self.forward = self.forward_for_export # type: ignore[assignment, method-assign]
1270
+ onnx_converter(
1271
+ model_name=self.cfg.model_name,
1272
+ out_dir=dir_path,
1273
+ module=self,
1274
+ inputs=self.encoder.input_example(),
1275
+ input_names=["features", "feature_lengths"],
1276
+ output_names=["log_probs"],
1277
+ dynamic_axes={
1278
+ "features": {0: "batch_size", 2: "seq_len"},
1279
+ "feature_lengths": {0: "batch_size"},
1280
+ "log_probs": {0: "batch_size", 1: "seq_len"},
1281
+ },
1282
+ )
1283
+ self.forward = saved_forward # type: ignore[assignment, method-assign]
1284
+ else:
1285
+ super()._to_onnx(dir_path) # export encoder
1286
+ onnx_converter(
1287
+ model_name=f"{self.cfg.model_name}_decoder",
1288
+ out_dir=dir_path,
1289
+ module=self.head.decoder,
1290
+ )
1291
+ onnx_converter(
1292
+ model_name=f"{self.cfg.model_name}_joint",
1293
+ out_dir=dir_path,
1294
+ module=self.head.joint,
1295
+ )
1296
+
1297
+ @torch.inference_mode()
1298
+ def transcribe_longform(
1299
+ self, wav_file: str, **kwargs
1300
+ ) -> List[Dict[str, Union[str, Tuple[float, float]]]]:
1301
+ """
1302
+ Transcribes a long audio file by splitting it into segments and
1303
+ then transcribing each segment.
1304
+ """
1305
+ transcribed_segments = []
1306
+ segments, boundaries = segment_audio_file(
1307
+ wav_file, SAMPLE_RATE, device=self._device, **kwargs
1308
+ )
1309
+ for segment, segment_boundaries in zip(segments, boundaries):
1310
+ wav = segment.to(self._device).unsqueeze(0).to(self._dtype)
1311
+ length = torch.full([1], wav.shape[-1], device=self._device)
1312
+ encoded, encoded_len = self.forward(wav, length)
1313
+ result = self.decoding.decode(self.head, encoded, encoded_len)[0]
1314
+ transcribed_segments.append(
1315
+ {"transcription": result, "boundaries": segment_boundaries}
1316
+ )
1317
+ return transcribed_segments
1318
+
1319
+
1320
+ class GigaAMEmo(GigaAM):
1321
+ """
1322
+ Giga Acoustic Model for Emotion Recognition
1323
+ """
1324
+
1325
+ def __init__(self, cfg: omegaconf.DictConfig):
1326
+ super().__init__(cfg)
1327
+ self.head = hydra.utils.instantiate(self.cfg.head)
1328
+ self.id2name = cfg.id2name
1329
+
1330
+ def get_probs(self, wav_file: str) -> Dict[str, float]:
1331
+ """
1332
+ Calculate probabilities for each emotion class based on the provided audio file.
1333
+ """
1334
+ wav, length = self.prepare_wav(wav_file)
1335
+ encoded, _ = self.forward(wav, length)
1336
+ encoded_pooled = nn.functional.avg_pool1d(
1337
+ encoded, kernel_size=encoded.shape[-1]
1338
+ ).squeeze(-1)
1339
+
1340
+ logits = self.head(encoded_pooled)[0]
1341
+ probs = nn.functional.softmax(logits, dim=-1).detach().tolist()
1342
+
1343
+ return {self.id2name[i]: probs[i] for i in range(len(self.id2name))}
1344
+
1345
+ def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor:
1346
+ """
1347
+ Encoder-decoder forward pass for exporting the whole model as a single ONNX graph.
1348
+ """
1349
+ encoded, _ = self.encoder(features, feature_lengths)
1350
+ enc_pooled = encoded.mean(dim=-1)
1351
+ return nn.functional.softmax(self.head(enc_pooled), dim=-1)
1352
+
1353
+ def _to_onnx(self, dir_path: str = ".") -> None:
1354
+ """
1355
+ Export onnx Emo model.
1356
+ """
1357
+ saved_forward = self.forward
1358
+ self.forward = self.forward_for_export # type: ignore[assignment, method-assign]
1359
+ onnx_converter(
1360
+ model_name=self.cfg.model_name,
1361
+ out_dir=dir_path,
1362
+ module=self,
1363
+ inputs=self.encoder.input_example(),
1364
+ input_names=["features", "feature_lengths"],
1365
+ output_names=["probs"],
1366
+ dynamic_axes={
1367
+ "features": {0: "batch_size", 2: "seq_len"},
1368
+ "feature_lengths": {0: "batch_size"},
1369
+ "probs": {0: "batch_size", 1: "seq_len"},
1370
+ },
1371
+ )
1372
+ self.forward = saved_forward # type: ignore[assignment, method-assign]
1373
+
1374
+
1375
+ ### transformers ###
1376
+
1377
+
1378
+ class GigaAMConfig(PretrainedConfig):
1379
+ model_type = "gigaam"
1380
+
1381
+ def __init__(self, cfg: omegaconf.DictConfig = None, **kwargs):
1382
+ super().__init__(**kwargs)
1383
+ self.cfg = cfg
1384
+
1385
+
1386
+ class GigaAMModel(PreTrainedModel):
1387
+ config_class = GigaAMConfig
1388
+ base_model_prefix = "gigaam"
1389
+
1390
+ def __init__(self, config: GigaAMConfig):
1391
+ super().__init__(config)
1392
+ self.config = config
1393
+ if "decoding" in self.config.cfg["model"]["cfg"] and "model_path" in self.config.cfg["model"]["cfg"]["decoding"]:
1394
+ resolved_tokenizer_path = cached_file(
1395
+ config.name_or_path,
1396
+ "tokenizer.model",
1397
+ revision=getattr(config, "_commit_hash", None),
1398
+ cache_dir=getattr(config, "cache_dir", None),
1399
+ use_auth_token=getattr(config, "use_auth_token", None),
1400
+ )
1401
+ self.config.cfg["model"]["cfg"]["decoding"]["model_path"] = resolved_tokenizer_path
1402
+
1403
+ self.model = instantiate(config.cfg["model"], _recursive_=False)
1404
+
1405
+ def forward(self, features: torch.Tensor, feature_lengths: torch.Tensor):
1406
+ return self.model(features, feature_lengths)
1407
+
1408
+ def embed_audio(self, wav_file: str) -> torch.Tensor:
1409
+ return self.model.embed_audio(wav_file)
1410
+
1411
+ def transcribe(self, wav_file: str) -> str:
1412
+ return self.model.transcribe(wav_file)
1413
+
1414
+ def transcribe_longform(self, wav_file: str) -> List[Dict[str, Union[str, Tuple[float, float]]]]:
1415
+ return self.model.transcribe_longform(wav_file)
1416
+
1417
+ def get_probs(self, wav_file: str) -> Dict[str, float]:
1418
+ return self.model.get_probs(wav_file)
1419
+
1420
+ @torch.no_grad()
1421
+ def to_onnx(self, dir_path: str = ".") -> None:
1422
+ self.model.to_onnx(dir_path)
1423
+
1424
+ @classmethod
1425
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1426
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
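
GigaAMModel wraps the GigaAMASR pipeline defined above and exposes transcribe, transcribe_longform and embed_audio. A usage sketch under stated assumptions: the repository id and audio paths are placeholders, ffmpeg must be on PATH because load_audio decodes audio through it, and long-form transcription additionally needs pyannote.audio plus an HF_TOKEN with access to pyannote/segmentation-3.0.

import os
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<org>/<repo>", trust_remote_code=True)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

# Short-form: the audio must be under LONGFORM_THRESHOLD (25 s at 16 kHz).
print(model.transcribe("example.wav"))

# Long-form: VAD-based segmentation, then per-segment RNNT decoding.
os.environ.setdefault("HF_TOKEN", "<token>")
for utt in model.transcribe_longform("long_example.wav"):
    start, end = utt["boundaries"]
    print(f"[{start:7.2f} - {end:7.2f}] {utt['transcription']}")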
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afc6dcbae8320ea56f2cddebc0f13fbf62c9d59b6ddcad899782623c8610826a
3
+ size 448928167
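
The LFS pointer records the SHA-256 of the checkpoint contents, so a local download can be checked directly against it. A small verification sketch; the local path is assumed to point at the downloaded file:

import hashlib

EXPECTED = "afc6dcbae8320ea56f2cddebc0f13fbf62c9d59b6ddcad899782623c8610826a"
sha = hashlib.sha256()
with open("pytorch_model.bin", "rb") as ckpt:  # placeholder local path
    for chunk in iter(lambda: ckpt.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == EXPECTED, "pytorch_model.bin does not match the LFS pointer"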
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:828c12c991019eef952a960661f25a92d6ad279591e2ea466b4aeddf1d20a18a
3
+ size 255336
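
tokenizer.model is the SentencePiece model that the Tokenizer class in modeling_gigaam.py loads for RNNT detokenization (the decoding config points model_path at it). A sketch for inspecting it directly; the hf_hub_download arguments are illustrative and the repository id is a placeholder:

from huggingface_hub import hf_hub_download
from sentencepiece import SentencePieceProcessor

sp_path = hf_hub_download("<org>/<repo>", "tokenizer.model")
sp = SentencePieceProcessor()
sp.load(sp_path)

print(len(sp))                 # vocabulary size; the RNNT blank id is num_classes - 1
print(sp.decode([5, 10, 42]))  # detokenize arbitrary ids, as Tokenizer.decode does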