AvtnshM committed (verified)
Commit 771dc21 · 1 Parent(s): d0747b6
Files changed (1)
  1. app.py +145 -110
app.py CHANGED
@@ -2,206 +2,241 @@ import gradio as gr
  import torch
  import librosa
  import numpy as np
- from transformers import pipeline
+ from transformers import pipeline, AutoConfig
  import gc
  import warnings
+ import os
  warnings.filterwarnings("ignore")

- class OptimizedShukaASR:
+ # Set environment variables for optimization
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ class UltraLightShukaASR:
      def __init__(self):
          self.pipe = None
-         self.load_model()
-
-     def load_model(self):
-         """Load model with optimizations for CPU inference"""
+         self.model_loaded = False
+
+     def load_model_lazy(self):
+         """Lazy load model only when needed"""
+         if self.model_loaded:
+             return True
+
          try:
-             # Force CPU usage and optimize for inference
+             print("Loading Shuka v1 model...")
+
+             # Try with minimal configuration first
              self.pipe = pipeline(
                  model='sarvamai/shuka_v1',
                  trust_remote_code=True,
-                 device=-1,  # Force CPU
+                 device=-1,  # CPU only
                  model_kwargs={
-                     "torch_dtype": torch.float32,  # Use float32 for CPU
                      "low_cpu_mem_usage": True,
-                     "use_cache": True,
+                     "use_cache": False,  # Disable cache to save memory
+                     "torch_dtype": torch.float32,
                  }
              )

-             # Set to eval mode and optimize
-             if hasattr(self.pipe.model, 'eval'):
-                 self.pipe.model.eval()
-
-             # Compile for faster inference (PyTorch 2.0+)
-             try:
-                 self.pipe.model = torch.compile(self.pipe.model, mode="reduce-overhead")
-             except:
-                 pass  # Skip if torch.compile not available
-
-             print("Model loaded successfully with optimizations")
+             print("✅ Model loaded successfully!")
+             self.model_loaded = True
+             return True

          except Exception as e:
-             print(f"Error loading model: {e}")
-             self.pipe = None
+             print(f"❌ Model loading failed: {e}")
+             return False

-     def preprocess_audio(self, audio_input, target_sr=16000, max_duration=30):
-         """Preprocess audio with length limiting and optimization"""
+     def preprocess_audio_minimal(self, audio_input, target_sr=16000, max_duration=15):
+         """Minimal audio preprocessing for speed"""
          try:
              if isinstance(audio_input, tuple):
                  sr, audio_data = audio_input
                  audio_data = audio_data.astype(np.float32)
                  if len(audio_data.shape) > 1:
-                     audio_data = audio_data.mean(axis=1)  # Convert to mono
-                 audio_data = audio_data / np.max(np.abs(audio_data))  # Normalize
+                     audio_data = np.mean(audio_data, axis=1)
              else:
-                 audio_data, sr = librosa.load(audio_input, sr=target_sr)
+                 audio_data, sr = librosa.load(audio_input, sr=target_sr, duration=max_duration)

-             # Limit audio duration to reduce processing time
-             max_samples = int(max_duration * target_sr)
-             if len(audio_data) > max_samples:
-                 audio_data = audio_data[:max_samples]
-                 print(f"Audio truncated to {max_duration} seconds")
+             # Quick normalization
+             if np.max(np.abs(audio_data)) > 0:
+                 audio_data = audio_data / np.max(np.abs(audio_data))

-             # Resample if needed
-             if sr != target_sr:
-                 audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=target_sr)
+             # Trim silence from start and end
+             audio_data, _ = librosa.effects.trim(audio_data, top_db=20)

              return audio_data, target_sr

          except Exception as e:
              raise Exception(f"Audio preprocessing failed: {e}")

-     def transcribe(self, audio_input, language="auto"):
-         """Transcribe audio to text"""
-         if self.pipe is None:
-             return "Model not loaded. Please check the setup."
+     def transcribe_fast(self, audio_input, language_hint=""):
+         """Fast transcription with minimal overhead"""
+
+         # Lazy load model
+         if not self.load_model_lazy():
+             return "❌ Model failed to load. Please check your setup."

          try:
-             # Preprocess audio
-             audio, sr = self.preprocess_audio(audio_input)
+             # Quick audio processing
+             audio, sr = self.preprocess_audio_minimal(audio_input)

-             # Prepare system prompt for ASR only
-             if language == "auto":
-                 system_prompt = "Transcribe the following audio accurately. Only provide the transcription, nothing else."
-             else:
-                 system_prompt = f"Transcribe the following audio in {language}. Only provide the transcription, nothing else."
+             # Minimal system prompt for speed
+             system_content = "Transcribe audio to text."
+             if language_hint and language_hint != "auto":
+                 system_content += f" Language: {language_hint}."

              turns = [
-                 {'role': 'system', 'content': system_prompt},
+                 {'role': 'system', 'content': system_content},
                  {'role': 'user', 'content': '<|audio|>'}
              ]

-             # Run inference with memory optimization
-             with torch.no_grad():
+             # Fast inference settings
+             with torch.inference_mode():  # More efficient than no_grad
                  result = self.pipe(
                      {
                          'audio': audio,
                          'turns': turns,
                          'sampling_rate': sr
                      },
-                     max_new_tokens=256,  # Reduced for ASR only
-                     do_sample=False,  # Deterministic output
-                     temperature=0.1,  # Low temperature for accuracy
-                     pad_token_id=self.pipe.tokenizer.eos_token_id
+                     max_new_tokens=128,  # Reduced further
+                     do_sample=False,  # Deterministic
+                     num_beams=1,  # No beam search
+                     early_stopping=True,  # Stop as soon as possible
+                     pad_token_id=self.pipe.tokenizer.eos_token_id if hasattr(self.pipe, 'tokenizer') else None
                  )

-             # Clean up memory
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
+             # Immediate cleanup
+             del audio
              gc.collect()

-             # Extract transcription
+             # Extract result
              if isinstance(result, list) and len(result) > 0:
-                 transcription = result[0].get('generated_text', '').strip()
+                 text = result[0].get('generated_text', '').strip()
              elif isinstance(result, dict):
-                 transcription = result.get('generated_text', '').strip()
+                 text = result.get('generated_text', '').strip()
              else:
-                 transcription = str(result).strip()
+                 text = str(result).strip()

-             return transcription
+             # Clean up the output (remove system prompts if they appear)
+             if "Transcribe audio to text" in text:
+                 text = text.replace("Transcribe audio to text", "").strip()
+             if text.startswith("Language:"):
+                 text = text.split(".", 1)[-1].strip() if "." in text else text
+
+             return text if text else "No speech detected"

          except Exception as e:
-             return f"Transcription failed: {str(e)}"
+             return f"❌ Transcription error: {str(e)}"

- # Initialize the ASR system
- asr_system = OptimizedShukaASR()
+ # Initialize ASR system
+ print("Initializing Ultra-Light Shuka ASR...")
+ asr_system = UltraLightShukaASR()

- def transcribe_audio(audio, language):
-     """Gradio interface function"""
+ def process_audio(audio, language):
+     """Main processing function"""
      if audio is None:
-         return "Please provide an audio file."
+         return "Please upload or record an audio file."

-     result = asr_system.transcribe(audio, language)
-     return result
+     return asr_system.transcribe_fast(audio, language)

- # Language options
- languages = [
-     ("Auto-detect", "auto"),
+ # Simple language options
+ LANGUAGES = [
+     ("Auto", "auto"),
      ("English", "english"),
      ("Hindi", "hindi"),
      ("Bengali", "bengali"),
+     ("Tamil", "tamil"),
+     ("Telugu", "telugu"),
      ("Gujarati", "gujarati"),
      ("Kannada", "kannada"),
      ("Malayalam", "malayalam"),
      ("Marathi", "marathi"),
-     ("Oriya", "oriya"),
      ("Punjabi", "punjabi"),
-     ("Tamil", "tamil"),
-     ("Telugu", "telugu")
+     ("Oriya", "oriya")
  ]

- # Create Gradio interface
- with gr.Blocks(title="Shuka v1 ASR - Multilingual Speech Recognition") as demo:
-     gr.Markdown("# 🎙️ Shuka v1 ASR - Fast Multilingual Transcription")
-     gr.Markdown("Upload an audio file or record directly to get transcription in multiple Indic languages.")
+ # Ultra-minimal Gradio interface
+ css = """
+ .gradio-container {
+     max-width: 800px !important;
+ }
+ .output-text textarea {
+     font-size: 16px !important;
+ }
+ """
+
+ with gr.Blocks(css=css, title="Fast Shuka ASR") as demo:
+
+     gr.HTML("""
+     <div style='text-align: center; margin-bottom: 20px;'>
+         <h1>🚀 Ultra-Fast Shuka v1 ASR</h1>
+         <p>Optimized for speed • Multilingual • 15-second max clips</p>
+     </div>
+     """)

      with gr.Row():
-         with gr.Column():
+         with gr.Column(scale=1):
              audio_input = gr.Audio(
-                 label="Audio Input",
+                 label="🎙️ Audio Input",
                  type="filepath",
-                 format="wav"
+                 format="wav",
+                 elem_id="audio-input"
              )
-             language_dropdown = gr.Dropdown(
-                 choices=languages,
+
+             language_select = gr.Dropdown(
+                 choices=LANGUAGES,
                  value="auto",
-                 label="Language (optional)"
+                 label="🌍 Language Hint",
+                 info="Optional - helps with accuracy"
+             )
+
+         with gr.Column(scale=2):
+             output_box = gr.Textbox(
+                 label="📝 Transcription",
+                 placeholder="Upload audio to see transcription here...",
+                 lines=8,
+                 elem_classes=["output-text"]
              )
-             transcribe_btn = gr.Button("🚀 Transcribe", variant="primary")

-         with gr.Column():
-             output_text = gr.Textbox(
-                 label="Transcription",
-                 placeholder="Transcription will appear here...",
-                 lines=10
-             )
+     gr.Button("🔄 Clear", size="sm").click(
+         lambda: ("", None),
+         outputs=[output_box, audio_input]
+     )

-     # Event handlers
-     transcribe_btn.click(
-         fn=transcribe_audio,
-         inputs=[audio_input, language_dropdown],
-         outputs=output_text
-     )
-
-     # Auto-transcribe on audio upload
+     # Auto-transcribe on upload
      audio_input.change(
-         fn=transcribe_audio,
-         inputs=[audio_input, language_dropdown],
-         outputs=output_text
+         fn=process_audio,
+         inputs=[audio_input, language_select],
+         outputs=output_box,
+         show_progress=True
      )

-     # Examples section
-     gr.Markdown("## 📝 Tips for best results:")
-     gr.Markdown("""
-     - Audio files are automatically limited to 30 seconds for faster processing
-     - Supported formats: WAV, MP3, M4A, WEBM
-     - For best accuracy, use clear audio with minimal background noise
-     - The model supports 11 Indic languages + English
-     """)
+     # Also trigger on language change
+     language_select.change(
+         fn=process_audio,
+         inputs=[audio_input, language_select],
+         outputs=output_box,
+         show_progress=True
+     )
+
+     gr.HTML("""
+     <div style='margin-top: 20px; padding: 15px; background: #f0f0f0; border-radius: 10px;'>
+         <h4>⚡ Speed Optimizations Active:</h4>
+         <ul style='margin: 10px 0;'>
+             <li>✅ Auto audio trimming (15s max)</li>
+             <li>✅ CPU-optimized inference</li>
+             <li>✅ Minimal token generation</li>
+             <li>✅ Memory cleanup after each request</li>
+         </ul>
+         <p><strong>Tip:</strong> For fastest results, use short, clear audio clips in WAV format.</p>
+     </div>
+     """)

  if __name__ == "__main__":
+     demo.queue(max_size=3)  # Limit concurrent requests
      demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
          share=False,
-         show_error=True
+         show_error=True,
+         quiet=False
      )
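
A quick way to sanity-check the new preprocess_audio_minimal path without downloading the model is to feed it a synthetic (sample_rate, ndarray) tuple, the same shape gr.Audio produces in numpy mode. A minimal sketch, assuming this commit's file is importable as a module named app; the module name and the test signal are assumptions, not part of the commit:

# Hypothetical smoke test for preprocess_audio_minimal; no model download,
# since UltraLightShukaASR.__init__ is lazy and never touches the pipeline.
import numpy as np
from app import UltraLightShukaASR  # assumes this file is saved as app.py

asr = UltraLightShukaASR()

# 2 s of 440 Hz stereo at 44.1 kHz, shaped like gr.Audio's (sr, data) tuple.
sr = 44100
t = np.linspace(0, 2.0, 2 * sr, endpoint=False)
stereo = np.stack([np.sin(2 * np.pi * 440 * t)] * 2, axis=1).astype(np.float32)

audio, out_sr = asr.preprocess_audio_minimal((sr, stereo))
print(audio.shape, out_sr)  # mono 1-D array; out_sr is the 16000 target

Note that the tuple branch never resamples, so out_sr is reported as 16000 while the samples remain at 44.1 kHz; that mismatch may be worth a follow-up commit.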
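For an end-to-end check against the running app, gradio_client can call the auto-generated endpoint. A sketch assuming Gradio's default function-name endpoint (/process_audio) and a local test clip; both are assumptions rather than anything pinned down by this commit:

# Hypothetical end-to-end call against the app on localhost:7860.
# api_name is assumed from Gradio's default fn-named endpoints; older
# gradio_client versions accept a plain filepath instead of handle_file.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")
text = client.predict(
    handle_file("sample.wav"),  # hypothetical test clip under 15 s
    "hindi",
    api_name="/process_audio",
)
print(text)

Because of demo.queue(max_size=3), callers beyond the queue limit are turned away, so a load test should expect queue-full errors rather than slow responses.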