SeasonalFall84 commited on
Commit
8e9e85e
·
1 Parent(s): 8fcb613

Merge exp/tts-modal-env: Add TTS AudioRefiner with LLM polish, Modal deployment, and async fixes

Browse files
.gitignore CHANGED
@@ -57,6 +57,9 @@ reference_repos/DeepCritical/
57
  # Keep the README in reference_repos
58
  !reference_repos/README.md
59
 
 
 
 
60
  # OS
61
  .DS_Store
62
  Thumbs.db
 
57
  # Keep the README in reference_repos
58
  !reference_repos/README.md
59
 
60
+ # Development directory
61
+ dev/
62
+
63
  # OS
64
  .DS_Store
65
  Thumbs.db
deployments/README.md ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deployments
2
+
3
+ This directory contains infrastructure deployment scripts for DeepCritical services.
4
+
5
+ ## Modal Deployments
6
+
7
+ ### TTS Service (`modal_tts.py`)
8
+
9
+ Deploys the Kokoro TTS (Text-to-Speech) function to Modal's GPU infrastructure.
10
+
11
+ **Deploy:**
12
+ ```bash
13
+ modal deploy deployments/modal_tts.py
14
+ ```
15
+
16
+ **Features:**
17
+ - Kokoro 82M TTS model
18
+ - GPU-accelerated (T4)
19
+ - Voice options: af_heart, af_bella, am_michael, etc.
20
+ - Configurable speech speed
21
+
22
+ **Requirements:**
23
+ - Modal account and credentials (`MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` in `.env`)
24
+ - GPU quota on Modal
25
+
26
+ **After Deployment:**
27
+ The function will be available at:
28
+ - App: `deepcritical-tts`
29
+ - Function: `kokoro_tts_function`
30
+
31
+ The main application (`src/services/tts_modal.py`) will call this deployed function.
32
+
33
+ ---
34
+
35
+ ## Adding New Deployments
36
+
37
+ When adding new deployment scripts:
38
+
39
+ 1. Create a new file: `deployments/<service_name>.py`
40
+ 2. Use Modal's app pattern:
41
+ ```python
42
+ import modal
43
+ app = modal.App("deepcritical-<service-name>")
44
+ ```
45
+ 3. Document in this README
46
+ 4. Test deployment: `modal deploy deployments/<service_name>.py`
deployments/modal_tts.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deploy Kokoro TTS function to Modal.
2
+
3
+ This script deploys the TTS function to Modal so it can be called
4
+ from the main DeepCritical application.
5
+
6
+ Usage:
7
+ modal deploy deploy_modal_tts.py
8
+
9
+ After deployment, the function will be available at:
10
+ App: deepcritical-tts
11
+ Function: kokoro_tts_function
12
+ """
13
+
14
+ import modal
15
+ import numpy as np
16
+
17
+ # Create Modal app
18
+ app = modal.App("deepcritical-tts")
19
+
20
+ # Define Kokoro TTS dependencies
21
+ KOKORO_DEPENDENCIES = [
22
+ "torch>=2.0.0",
23
+ "transformers>=4.30.0",
24
+ "numpy<2.0",
25
+ ]
26
+
27
+ # Create Modal image with Kokoro
28
+ tts_image = (
29
+ modal.Image.debian_slim(python_version="3.11")
30
+ .apt_install("git") # Install git first for pip install from github
31
+ .pip_install(*KOKORO_DEPENDENCIES)
32
+ .pip_install("git+https://github.com/hexgrad/kokoro.git")
33
+ )
34
+
35
+
36
+ @app.function(
37
+ image=tts_image,
38
+ gpu="T4",
39
+ timeout=60,
40
+ )
41
+ def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]:
42
+ """Modal GPU function for Kokoro TTS.
43
+
44
+ This function runs on Modal's GPU infrastructure.
45
+ Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
46
+
47
+ Args:
48
+ text: Text to synthesize
49
+ voice: Voice ID (e.g., af_heart, af_bella, am_michael)
50
+ speed: Speech speed multiplier (0.5-2.0)
51
+
52
+ Returns:
53
+ Tuple of (sample_rate, audio_array)
54
+ """
55
+ import numpy as np
56
+
57
+ try:
58
+ import torch
59
+ from kokoro import KModel, KPipeline
60
+
61
+ # Initialize model (cached on GPU)
62
+ model = KModel().to("cuda").eval()
63
+ pipeline = KPipeline(lang_code=voice[0])
64
+ pack = pipeline.load_voice(voice)
65
+
66
+ # Generate audio - accumulate all chunks
67
+ audio_chunks = []
68
+ for _, ps, _ in pipeline(text, voice, speed):
69
+ ref_s = pack[len(ps) - 1]
70
+ audio = model(ps, ref_s, speed)
71
+ audio_chunks.append(audio.numpy())
72
+
73
+ # Concatenate all audio chunks
74
+ if audio_chunks:
75
+ full_audio = np.concatenate(audio_chunks)
76
+ return (24000, full_audio)
77
+
78
+ # If no audio generated, return empty
79
+ return (24000, np.zeros(1, dtype=np.float32))
80
+
81
+ except ImportError as e:
82
+ raise RuntimeError(
83
+ f"Kokoro not installed: {e}. "
84
+ "Install with: pip install git+https://github.com/hexgrad/kokoro.git"
85
+ ) from e
86
+ except Exception as e:
87
+ raise RuntimeError(f"TTS synthesis failed: {e}") from e
88
+
89
+
90
+ # Optional: Add a test entrypoint
91
+ @app.local_entrypoint()
92
+ def test():
93
+ """Test the TTS function."""
94
+ print("Testing Modal TTS function...")
95
+ sample_rate, audio = kokoro_tts_function.remote(
96
+ "Hello, this is a test.",
97
+ "af_heart",
98
+ 1.0
99
+ )
100
+ print(f"Generated audio: {sample_rate}Hz, shape={audio.shape}")
101
+ print("✓ TTS function works!")
dev/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Development utilities and plugins."""
 
 
src/agents/audio_refiner.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio Refiner Agent - Cleans markdown reports for TTS audio clarity.
2
+
3
+ This agent transforms markdown-formatted research reports into clean,
4
+ audio-friendly plain text suitable for text-to-speech synthesis.
5
+ """
6
+
7
+ import re
8
+ from typing import Optional
9
+
10
+ import structlog
11
+ from pydantic_ai import Agent
12
+
13
+ from src.utils.llm_factory import get_pydantic_ai_model
14
+
15
+ logger = structlog.get_logger(__name__)
16
+
17
+
18
+ class AudioRefiner:
19
+ """Refines markdown reports for optimal TTS audio output.
20
+
21
+ Handles common formatting issues that make text difficult to listen to:
22
+ - Markdown syntax (headers, bold, italic, links)
23
+ - Citations and reference markers
24
+ - Roman numerals in medical contexts
25
+ - Multiple References sections
26
+ - Special characters and formatting artifacts
27
+ """
28
+
29
+ # Roman numeral to integer mapping
30
+ ROMAN_VALUES = {
31
+ 'I': 1, 'V': 5, 'X': 10, 'L': 50,
32
+ 'C': 100, 'D': 500, 'M': 1000
33
+ }
34
+
35
+ # Number to word mapping (1-20, common in medical literature)
36
+ NUMBER_TO_WORD = {
37
+ 1: 'One', 2: 'Two', 3: 'Three', 4: 'Four', 5: 'Five',
38
+ 6: 'Six', 7: 'Seven', 8: 'Eight', 9: 'Nine', 10: 'Ten',
39
+ 11: 'Eleven', 12: 'Twelve', 13: 'Thirteen', 14: 'Fourteen',
40
+ 15: 'Fifteen', 16: 'Sixteen', 17: 'Seventeen', 18: 'Eighteen',
41
+ 19: 'Nineteen', 20: 'Twenty'
42
+ }
43
+
44
+ async def refine_for_audio(self, markdown_text: str, use_llm_polish: bool = False) -> str:
45
+ """Transform markdown report into audio-friendly plain text.
46
+
47
+ Args:
48
+ markdown_text: Markdown-formatted research report
49
+ use_llm_polish: If True, apply LLM-based final polish (optional)
50
+
51
+ Returns:
52
+ Clean plain text optimized for TTS audio
53
+ """
54
+ logger.info("Refining report for audio output", use_llm_polish=use_llm_polish)
55
+
56
+ text = markdown_text
57
+
58
+ # Step 1: Keep only content before first References section
59
+ text = self._remove_references_sections(text)
60
+
61
+ # Step 2: Remove markdown formatting
62
+ text = self._remove_markdown_syntax(text)
63
+
64
+ # Step 3: Convert roman numerals to words
65
+ text = self._convert_roman_numerals(text)
66
+
67
+ # Step 4: Remove citations
68
+ text = self._remove_citations(text)
69
+
70
+ # Step 5: Clean up special characters and artifacts
71
+ text = self._clean_special_characters(text)
72
+
73
+ # Step 6: Normalize whitespace
74
+ text = self._normalize_whitespace(text)
75
+
76
+ # Step 7 (Optional): LLM polish for edge cases
77
+ if use_llm_polish:
78
+ text = await self._llm_polish(text)
79
+
80
+ logger.info(
81
+ "Audio refinement complete",
82
+ original_length=len(markdown_text),
83
+ refined_length=len(text),
84
+ llm_polish_applied=use_llm_polish
85
+ )
86
+
87
+ return text.strip()
88
+
89
+ def _remove_references_sections(self, text: str) -> str:
90
+ """Remove References sections while preserving other content.
91
+
92
+ Removes the References section and its content until the next section
93
+ heading or end of document. Handles multiple References sections.
94
+
95
+ Matches various References heading formats:
96
+ - # References
97
+ - ## References
98
+ - **References:**
99
+ - **Additional References:**
100
+ """
101
+ # Pattern to match References section heading (case-insensitive)
102
+ # Only matches headings that contain "Reference" or "References"
103
+ references_pattern = r'\n(?:#+\s*References?:?\s*\n|\*\*\s*(?:Additional\s+)?References?:?\s*\*\*\s*\n)'
104
+
105
+ # Find all References sections
106
+ while True:
107
+ match = re.search(references_pattern, text, re.IGNORECASE)
108
+ if not match:
109
+ break
110
+
111
+ # Find the start of the References section
112
+ section_start = match.start()
113
+
114
+ # Find the next section (markdown header or bold heading) or end of document
115
+ # Match: "# Header", "## Header", or "**Header**"
116
+ next_section_patterns = [
117
+ r'\n#+\s+\w+', # Markdown headers (# Section, ## Section)
118
+ r'\n\*\*[A-Z][^*]+\*\*', # Bold headings (**Section Name**)
119
+ ]
120
+
121
+ remaining_text = text[match.end():]
122
+ next_section_match = None
123
+
124
+ # Try all patterns and find the earliest match
125
+ earliest_match = None
126
+ for pattern in next_section_patterns:
127
+ m = re.search(pattern, remaining_text)
128
+ if m and (earliest_match is None or m.start() < earliest_match.start()):
129
+ earliest_match = m
130
+
131
+ next_section_match = earliest_match
132
+
133
+ if next_section_match:
134
+ # Remove from References heading to next section
135
+ section_end = match.end() + next_section_match.start()
136
+ else:
137
+ # No next section - remove to end of document
138
+ section_end = len(text)
139
+
140
+ # Remove the References section
141
+ text = text[:section_start] + text[section_end:]
142
+ logger.debug(
143
+ "Removed References section",
144
+ removed_chars=section_end - section_start
145
+ )
146
+
147
+ return text
148
+
149
+ def _remove_markdown_syntax(self, text: str) -> str:
150
+ """Remove markdown formatting syntax."""
151
+
152
+ # Headers (# ## ###)
153
+ text = re.sub(r'^\s*#+\s+', '', text, flags=re.MULTILINE)
154
+
155
+ # Bold (**text** or __text__)
156
+ text = re.sub(r'\*\*([^*]+)\*\*', r'\1', text)
157
+ text = re.sub(r'__([^_]+)__', r'\1', text)
158
+
159
+ # Italic (*text* or _text_)
160
+ text = re.sub(r'\*([^*]+)\*', r'\1', text)
161
+ text = re.sub(r'_([^_]+)_', r'\1', text)
162
+
163
+ # Links [text](url) → text
164
+ text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
165
+
166
+ # Inline code `code` → code
167
+ text = re.sub(r'`([^`]+)`', r'\1', text)
168
+
169
+ # Strikethrough ~~text~~
170
+ text = re.sub(r'~~([^~]+)~~', r'\1', text)
171
+
172
+ # Blockquotes (> text)
173
+ text = re.sub(r'^\s*>\s+', '', text, flags=re.MULTILINE)
174
+
175
+ # Horizontal rules (---, ***, ___)
176
+ text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
177
+
178
+ # List markers (-, *, 1., 2.)
179
+ text = re.sub(r'^\s*[-*]\s+', '', text, flags=re.MULTILINE)
180
+ text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
181
+
182
+ return text
183
+
184
+ def _roman_to_int(self, roman: str) -> Optional[int]:
185
+ """Convert roman numeral string to integer.
186
+
187
+ Args:
188
+ roman: Roman numeral string (e.g., 'IV', 'XII')
189
+
190
+ Returns:
191
+ Integer value, or None if invalid roman numeral
192
+ """
193
+ roman = roman.upper()
194
+ result = 0
195
+ prev_value = 0
196
+
197
+ for char in reversed(roman):
198
+ if char not in self.ROMAN_VALUES:
199
+ return None
200
+
201
+ value = self.ROMAN_VALUES[char]
202
+
203
+ # Subtractive notation (IV = 4, IX = 9)
204
+ if value < prev_value:
205
+ result -= value
206
+ else:
207
+ result += value
208
+
209
+ prev_value = value
210
+
211
+ return result
212
+
213
+ def _int_to_word(self, num: int) -> str:
214
+ """Convert integer to word representation.
215
+
216
+ Args:
217
+ num: Integer to convert (1-20 supported)
218
+
219
+ Returns:
220
+ Word representation (e.g., 'One', 'Twelve')
221
+ """
222
+ if num in self.NUMBER_TO_WORD:
223
+ return self.NUMBER_TO_WORD[num]
224
+ else:
225
+ # For numbers > 20, just return the digit
226
+ return str(num)
227
+
228
+ def _convert_roman_numerals(self, text: str) -> str:
229
+ """Convert roman numerals to words for better TTS pronunciation.
230
+
231
+ Handles patterns like:
232
+ - Phase I, Phase II, Phase III
233
+ - Trial I, Trial II
234
+ - Type I, Type II
235
+ - Stage I, Stage II
236
+ - Standalone I, II, III (with word boundaries)
237
+ """
238
+
239
+ def replace_roman(match):
240
+ """Callback to replace matched roman numeral."""
241
+ prefix = match.group(1) # Word before roman numeral (if any)
242
+ roman = match.group(2) # The roman numeral
243
+
244
+ # Convert to integer
245
+ num = self._roman_to_int(roman)
246
+ if num is None:
247
+ return match.group(0) # Return original if invalid
248
+
249
+ # Convert to word
250
+ word = self._int_to_word(num)
251
+
252
+ # Return with prefix if present
253
+ if prefix:
254
+ return f"{prefix} {word}"
255
+ else:
256
+ return word
257
+
258
+ # Pattern: Optional word + space + roman numeral
259
+ # Matches: "Phase I", "Trial II", standalone "I", "II"
260
+ # Uses word boundaries to avoid matching "I" in "INVALID"
261
+ pattern = r'\b(Phase|Trial|Type|Stage|Class|Group|Arm|Cohort)?\s*([IVXLCDM]+)\b'
262
+
263
+ text = re.sub(pattern, replace_roman, text)
264
+
265
+ return text
266
+
267
+ def _remove_citations(self, text: str) -> str:
268
+ """Remove citation markers and references."""
269
+
270
+ # Numbered citations [1], [2], [1,2], [1-3]
271
+ text = re.sub(r'\[\d+(?:[-,]\d+)*\]', '', text)
272
+
273
+ # Author citations (Smith et al., 2023) or (Smith et al. 2023)
274
+ text = re.sub(r'\([A-Z][a-z]+\s+et\s+al\.?,?\s+\d{4}\)', '', text)
275
+
276
+ # Simple year citations (2023)
277
+ text = re.sub(r'\(\d{4}\)', '', text)
278
+
279
+ # Author-year (Smith, 2023)
280
+ text = re.sub(r'\([A-Z][a-z]+,?\s+\d{4}\)', '', text)
281
+
282
+ # Footnote markers (¹, ², ³)
283
+ text = re.sub(r'[¹²³⁴⁵⁶⁷⁸⁹⁰]+', '', text)
284
+
285
+ return text
286
+
287
+ def _clean_special_characters(self, text: str) -> str:
288
+ """Clean up special characters and formatting artifacts."""
289
+
290
+ # Replace em dashes with regular dashes
291
+ text = text.replace('\u2014', '-') # em dash
292
+ text = text.replace('\u2013', '-') # en dash
293
+
294
+ # Replace smart quotes with regular quotes
295
+ text = text.replace('\u201c', '"') # left double quote
296
+ text = text.replace('\u201d', '"') # right double quote
297
+ text = text.replace('\u2018', "'") # left single quote
298
+ text = text.replace('\u2019', "'") # right single quote
299
+
300
+ # Remove excessive punctuation (!!!, ???)
301
+ text = re.sub(r'([!?]){2,}', r'\1', text)
302
+
303
+ # Remove asterisks used for footnotes
304
+ text = re.sub(r'\*+', '', text)
305
+
306
+ # Remove hash symbols (from headers)
307
+ text = text.replace('#', '')
308
+
309
+ # Remove excessive dots (...)
310
+ text = re.sub(r'\.{4,}', '...', text)
311
+
312
+ return text
313
+
314
+ def _normalize_whitespace(self, text: str) -> str:
315
+ """Normalize whitespace for clean audio output."""
316
+
317
+ # Replace multiple spaces with single space
318
+ text = re.sub(r' {2,}', ' ', text)
319
+
320
+ # Replace multiple newlines with double newline (paragraph break)
321
+ text = re.sub(r'\n{3,}', '\n\n', text)
322
+
323
+ # Remove trailing/leading whitespace from lines
324
+ text = '\n'.join(line.strip() for line in text.split('\n'))
325
+
326
+ # Remove empty lines at start/end
327
+ text = text.strip()
328
+
329
+ return text
330
+
331
+ async def _llm_polish(self, text: str) -> str:
332
+ """Apply LLM-based final polish to catch edge cases.
333
+
334
+ This is a lightweight pass that removes any remaining formatting
335
+ artifacts the rule-based methods might have missed.
336
+
337
+ Args:
338
+ text: Pre-cleaned text from rule-based methods
339
+
340
+ Returns:
341
+ Final polished text ready for TTS
342
+ """
343
+ try:
344
+ # Create a simple agent for text cleanup
345
+ model = get_pydantic_ai_model()
346
+ polish_agent = Agent(
347
+ model=model,
348
+ system_prompt=(
349
+ "You are a text cleanup assistant. Your ONLY job is to remove "
350
+ "any remaining formatting artifacts (markdown, citations, special "
351
+ "characters) that make text unsuitable for text-to-speech audio. "
352
+ "DO NOT rewrite, improve, or change the content. "
353
+ "DO NOT add explanations. "
354
+ "ONLY output the cleaned text."
355
+ ),
356
+ )
357
+
358
+ # Run asynchronously
359
+ result = await polish_agent.run(
360
+ f"Clean this text for audio (remove any formatting artifacts):\n\n{text}"
361
+ )
362
+
363
+ polished_text = result.output.strip()
364
+
365
+ logger.info(
366
+ "llm_polish_applied",
367
+ original_length=len(text),
368
+ polished_length=len(polished_text)
369
+ )
370
+
371
+ return polished_text
372
+
373
+ except Exception as e:
374
+ logger.warning(
375
+ "llm_polish_failed",
376
+ error=str(e),
377
+ message="Falling back to rule-based output"
378
+ )
379
+ # Graceful fallback: return original text if LLM fails
380
+ return text
381
+
382
+
383
+ # Singleton instance for easy import
384
+ audio_refiner = AudioRefiner()
385
+
386
+
387
+ async def refine_text_for_audio(markdown_text: str, use_llm_polish: bool = False) -> str:
388
+ """Convenience function to refine markdown text for audio.
389
+
390
+ Args:
391
+ markdown_text: Markdown-formatted text
392
+ use_llm_polish: If True, apply LLM-based final polish (optional)
393
+
394
+ Returns:
395
+ Audio-friendly plain text
396
+ """
397
+ return await audio_refiner.refine_for_audio(markdown_text, use_llm_polish=use_llm_polish)
src/app.py CHANGED
@@ -18,6 +18,7 @@ import structlog
18
 
19
  from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler
20
  from src.orchestrator_factory import create_orchestrator
 
21
  from src.services.multimodal_processing import get_multimodal_service
22
  from src.utils.config import settings
23
  from src.utils.models import AgentEvent, OrchestratorConfig
@@ -446,6 +447,7 @@ async def research_agent(
446
  enable_audio_input: bool = True,
447
  tts_voice: str = "af_heart",
448
  tts_speed: float = 1.0,
 
449
  web_search_provider: str = "auto",
450
  oauth_token: gr.OAuthToken | None = None,
451
  oauth_profile: gr.OAuthProfile | None = None,
@@ -465,6 +467,7 @@ async def research_agent(
465
  enable_audio_input: Whether to process audio inputs
466
  tts_voice: TTS voice selection
467
  tts_speed: TTS speech speed
 
468
  web_search_provider: Web search provider selection
469
  oauth_token: Gradio OAuth token (None if user not logged in)
470
  oauth_profile: Gradio OAuth profile (None if user not logged in)
@@ -585,17 +588,23 @@ async def research_agent(
585
  # Optional: Generate audio output if enabled
586
  if settings.enable_audio_output and settings.modal_available:
587
  try:
588
- from src.services.tts_modal import get_tts_service
589
-
590
- tts_service = get_tts_service()
591
  # Get the last message from history for TTS
592
  last_message = history[-1].get("content", "") if history else processed_text
593
  if last_message:
594
- await tts_service.synthesize_async(
595
- text=last_message,
596
- voice=tts_voice,
597
- speed=tts_speed,
598
- )
 
 
 
 
 
 
 
 
599
  except Exception as e:
600
  logger.warning("audio_synthesis_failed", error=str(e))
601
  # Continue without audio output
@@ -1081,6 +1090,13 @@ def create_demo() -> gr.Blocks:
1081
  interactive=False, # GPU type set at function definition time, requires restart
1082
  )
1083
 
 
 
 
 
 
 
 
1084
  # Audio output component (for TTS response) - moved to sidebar
1085
  audio_output = gr.Audio(
1086
  label="🔊 Audio Response",
@@ -1091,18 +1107,19 @@ def create_demo() -> gr.Blocks:
1091
  # This must be after audio_output is defined
1092
  def update_tts_visibility(
1093
  enabled: bool,
1094
- ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
1095
  """Update visibility of TTS components based on enable checkbox."""
1096
  return (
1097
  gr.update(visible=enabled),
1098
  gr.update(visible=enabled),
1099
  gr.update(visible=enabled),
 
1100
  )
1101
 
1102
  enable_audio_output_checkbox.change(
1103
  fn=update_tts_visibility,
1104
  inputs=[enable_audio_output_checkbox],
1105
- outputs=[tts_voice_dropdown, tts_speed_slider, audio_output],
1106
  )
1107
 
1108
  # Chat interface with multimodal support
@@ -1196,6 +1213,7 @@ def create_demo() -> gr.Blocks:
1196
  enable_audio_input_checkbox,
1197
  tts_voice_dropdown,
1198
  tts_speed_slider,
 
1199
  web_search_provider_dropdown,
1200
  # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters
1201
  ],
 
18
 
19
  from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler
20
  from src.orchestrator_factory import create_orchestrator
21
+ from src.services.audio_processing import get_audio_service
22
  from src.services.multimodal_processing import get_multimodal_service
23
  from src.utils.config import settings
24
  from src.utils.models import AgentEvent, OrchestratorConfig
 
447
  enable_audio_input: bool = True,
448
  tts_voice: str = "af_heart",
449
  tts_speed: float = 1.0,
450
+ tts_use_llm_polish: bool = False,
451
  web_search_provider: str = "auto",
452
  oauth_token: gr.OAuthToken | None = None,
453
  oauth_profile: gr.OAuthProfile | None = None,
 
467
  enable_audio_input: Whether to process audio inputs
468
  tts_voice: TTS voice selection
469
  tts_speed: TTS speech speed
470
+ tts_use_llm_polish: Apply LLM-based final polish to audio text (costs API calls)
471
  web_search_provider: Web search provider selection
472
  oauth_token: Gradio OAuth token (None if user not logged in)
473
  oauth_profile: Gradio OAuth profile (None if user not logged in)
 
588
  # Optional: Generate audio output if enabled
589
  if settings.enable_audio_output and settings.modal_available:
590
  try:
591
+ audio_service = get_audio_service()
 
 
592
  # Get the last message from history for TTS
593
  last_message = history[-1].get("content", "") if history else processed_text
594
  if last_message:
595
+ # Temporarily override tts_use_llm_polish setting from UI
596
+ original_llm_polish = settings.tts_use_llm_polish
597
+ try:
598
+ settings.tts_use_llm_polish = tts_use_llm_polish
599
+ # Use UI-configured voice and speed, fallback to settings defaults
600
+ await audio_service.generate_audio_output(
601
+ text=last_message,
602
+ voice=tts_voice or settings.tts_voice,
603
+ speed=tts_speed if tts_speed else settings.tts_speed,
604
+ )
605
+ finally:
606
+ # Restore original setting
607
+ settings.tts_use_llm_polish = original_llm_polish
608
  except Exception as e:
609
  logger.warning("audio_synthesis_failed", error=str(e))
610
  # Continue without audio output
 
1090
  interactive=False, # GPU type set at function definition time, requires restart
1091
  )
1092
 
1093
+ tts_use_llm_polish_checkbox = gr.Checkbox(
1094
+ value=settings.tts_use_llm_polish,
1095
+ label="Use LLM Polish for Audio",
1096
+ info="Apply LLM-based final polish to remove remaining formatting artifacts (costs API calls)",
1097
+ visible=settings.enable_audio_output,
1098
+ )
1099
+
1100
  # Audio output component (for TTS response) - moved to sidebar
1101
  audio_output = gr.Audio(
1102
  label="🔊 Audio Response",
 
1107
  # This must be after audio_output is defined
1108
  def update_tts_visibility(
1109
  enabled: bool,
1110
+ ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]:
1111
  """Update visibility of TTS components based on enable checkbox."""
1112
  return (
1113
  gr.update(visible=enabled),
1114
  gr.update(visible=enabled),
1115
  gr.update(visible=enabled),
1116
+ gr.update(visible=enabled),
1117
  )
1118
 
1119
  enable_audio_output_checkbox.change(
1120
  fn=update_tts_visibility,
1121
  inputs=[enable_audio_output_checkbox],
1122
+ outputs=[tts_voice_dropdown, tts_speed_slider, tts_use_llm_polish_checkbox, audio_output],
1123
  )
1124
 
1125
  # Chat interface with multimodal support
 
1213
  enable_audio_input_checkbox,
1214
  tts_voice_dropdown,
1215
  tts_speed_slider,
1216
+ tts_use_llm_polish_checkbox,
1217
  web_search_provider_dropdown,
1218
  # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters
1219
  ],
src/services/audio_processing.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any
6
  import numpy as np
7
  import structlog
8
 
 
9
  from src.services.stt_gradio import STTService, get_stt_service
10
  from src.utils.config import settings
11
 
@@ -85,7 +86,7 @@ class AudioService:
85
  """Generate audio output from text.
86
 
87
  Args:
88
- text: Text to synthesize
89
  voice: Voice ID (default: settings.tts_voice)
90
  speed: Speech speed (default: settings.tts_speed)
91
 
@@ -101,11 +102,22 @@ class AudioService:
101
  return None
102
 
103
  try:
 
 
 
 
 
 
 
 
 
 
 
104
  # Use provided voice/speed or fallback to settings defaults
105
  voice = voice if voice else settings.tts_voice
106
  speed = speed if speed is not None else settings.tts_speed
107
 
108
- audio_output = await self.tts.synthesize_async(text, voice, speed) # type: ignore[misc]
109
 
110
  if audio_output:
111
  logger.info(
 
6
  import numpy as np
7
  import structlog
8
 
9
+ from src.agents.audio_refiner import audio_refiner
10
  from src.services.stt_gradio import STTService, get_stt_service
11
  from src.utils.config import settings
12
 
 
86
  """Generate audio output from text.
87
 
88
  Args:
89
+ text: Text to synthesize (markdown will be cleaned for audio)
90
  voice: Voice ID (default: settings.tts_voice)
91
  speed: Speech speed (default: settings.tts_speed)
92
 
 
102
  return None
103
 
104
  try:
105
+ # Refine text for audio (remove markdown, citations, etc.)
106
+ # Use LLM polish if enabled in settings
107
+ refined_text = await audio_refiner.refine_for_audio(
108
+ text,
109
+ use_llm_polish=settings.tts_use_llm_polish
110
+ )
111
+ logger.info("text_refined_for_audio",
112
+ original_length=len(text),
113
+ refined_length=len(refined_text),
114
+ llm_polish_enabled=settings.tts_use_llm_polish)
115
+
116
  # Use provided voice/speed or fallback to settings defaults
117
  voice = voice if voice else settings.tts_voice
118
  speed = speed if speed is not None else settings.tts_speed
119
 
120
+ audio_output = await self.tts.synthesize_async(refined_text, voice, speed) # type: ignore[misc]
121
 
122
  if audio_output:
123
  logger.info(
src/services/tts_modal.py CHANGED
@@ -1,12 +1,18 @@
1
  """Text-to-Speech service using Kokoro 82M via Modal GPU."""
2
 
3
  import asyncio
 
4
  from functools import lru_cache
5
  from typing import Any
6
 
7
  import numpy as np
8
  import structlog
9
 
 
 
 
 
 
10
  from src.utils.config import settings
11
  from src.utils.exceptions import ConfigurationError
12
 
@@ -24,39 +30,52 @@ KOKORO_DEPENDENCIES = [
24
  # Modal app and function definitions (module-level for Modal)
25
  _modal_app: Any | None = None
26
  _tts_function: Any | None = None
 
27
 
28
 
29
  def _get_modal_app() -> Any:
30
- """Get or create Modal app instance."""
 
 
 
 
31
  global _modal_app
32
  if _modal_app is None:
33
  try:
34
  import modal
35
 
36
- # Validate Modal credentials before attempting lookup
37
- if not settings.modal_available:
 
 
 
 
38
  raise ConfigurationError(
39
- "Modal credentials not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables."
 
40
  )
41
 
42
  # Validate token ID format (Modal token IDs are typically UUIDs or specific formats)
43
- token_id = settings.modal_token_id
44
- if token_id:
45
- # Basic validation: token ID should not be empty and should be a reasonable length
46
- if len(token_id.strip()) < 10:
47
- raise ConfigurationError(
48
- f"Modal token ID appears malformed (too short: {len(token_id)} chars). "
49
- "Token ID should be a valid Modal token identifier."
50
- )
 
 
 
51
 
52
  try:
53
- _modal_app = modal.App.lookup("deepcritical-tts", create_if_missing=True)
54
  except Exception as e:
55
  error_msg = str(e).lower()
56
  if "token" in error_msg or "malformed" in error_msg or "invalid" in error_msg:
57
  raise ConfigurationError(
58
  f"Modal token validation failed: {e}. "
59
- "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET are correctly set."
60
  ) from e
61
  raise
62
  except ImportError as e:
@@ -69,23 +88,92 @@ def _get_modal_app() -> Any:
69
  # Define Modal image with Kokoro dependencies (module-level)
70
  def _get_tts_image() -> Any:
71
  """Get Modal image with Kokoro dependencies."""
 
 
 
 
72
  try:
73
  import modal
74
 
75
- return (
76
  modal.Image.debian_slim(python_version="3.11")
77
  .pip_install(*KOKORO_DEPENDENCIES)
78
  .pip_install("git+https://github.com/hexgrad/kokoro.git")
79
  )
 
80
  except ImportError:
81
  return None
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def _setup_modal_function() -> None:
85
  """Setup Modal GPU function for TTS (called once, lazy initialization).
86
 
87
- Note: GPU type is set at function definition time. Changes to settings.tts_gpu
88
- require app restart to take effect.
 
 
89
  """
90
  global _tts_function
91
 
@@ -93,80 +181,27 @@ def _setup_modal_function() -> None:
93
  return # Already set up
94
 
95
  try:
96
- app = _get_modal_app()
97
- tts_image = _get_tts_image()
98
-
99
- if tts_image is None:
100
- raise ConfigurationError("Modal image setup failed")
101
-
102
- # Get GPU and timeout from settings (with defaults)
103
- # Note: These are evaluated at function definition time, not at call time
104
- # Changes to settings require app restart
105
- gpu_type = getattr(settings, "tts_gpu", None) or "T4"
106
- timeout_seconds = getattr(settings, "tts_timeout", None) or 60
107
-
108
- # Define GPU function at module level (required by Modal)
109
- # Modal functions are immutable once defined, so GPU changes require restart
110
- @app.function( # type: ignore[misc]
111
- image=tts_image,
112
- gpu=gpu_type,
113
- timeout=timeout_seconds,
114
- )
115
- def kokoro_tts_function(
116
- text: str, voice: str, speed: float
117
- ) -> tuple[int, np.ndarray[Any, Any]]: # type: ignore[type-arg]
118
- """Modal GPU function for Kokoro TTS.
119
-
120
- This function runs on Modal's GPU infrastructure.
121
- Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
122
- Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py
123
- """
124
- import numpy as np
125
-
126
- # Import Kokoro inside function (lazy load)
127
- try:
128
- from kokoro import KModel, KPipeline
129
-
130
- # Initialize model (cached on GPU)
131
- model = KModel().to("cuda").eval()
132
- pipeline = KPipeline(lang_code=voice[0])
133
- pack = pipeline.load_voice(voice)
134
-
135
- # Generate audio
136
- for _, ps, _ in pipeline(text, voice, speed):
137
- ref_s = pack[len(ps) - 1]
138
- audio = model(ps, ref_s, speed)
139
- return (24000, audio.numpy())
140
-
141
- # If no audio generated, return empty
142
- return (24000, np.zeros(1, dtype=np.float32))
143
-
144
- except ImportError as e:
145
- raise ConfigurationError(
146
- "Kokoro not installed. Install with: pip install git+https://github.com/hexgrad/kokoro.git"
147
- ) from e
148
- except Exception as e:
149
- raise ConfigurationError(f"TTS synthesis failed: {e}") from e
150
-
151
- # Store function reference for remote calls
152
- _tts_function = kokoro_tts_function
153
 
154
- # Verify function is properly attached to app
155
- if not hasattr(app, kokoro_tts_function.__name__):
156
- logger.warning(
157
- "modal_function_not_attached", function_name=kokoro_tts_function.__name__
158
- )
 
159
 
160
  logger.info(
161
- "modal_tts_function_setup_complete",
162
- gpu=gpu_type,
163
- timeout=timeout_seconds,
164
- function_name=kokoro_tts_function.__name__,
165
  )
166
 
167
  except Exception as e:
168
  logger.error("modal_tts_function_setup_failed", error=str(e))
169
- raise ConfigurationError(f"Failed to setup Modal TTS function: {e}") from e
 
 
 
170
 
171
 
172
  class ModalTTSExecutor:
@@ -180,13 +215,17 @@ class ModalTTSExecutor:
180
  """Initialize Modal TTS executor.
181
 
182
  Note:
183
- Logs a warning if Modal credentials are not configured.
184
- Execution will fail at runtime without valid credentials.
185
  """
186
- # Check for Modal credentials
187
- if not settings.modal_available:
 
 
 
188
  logger.warning(
189
- "Modal credentials not found. TTS will not be available unless modal setup is run."
 
190
  )
191
 
192
  def synthesize(
@@ -195,7 +234,7 @@ class ModalTTSExecutor:
195
  voice: str = "af_heart",
196
  speed: float = 1.0,
197
  timeout: int = 60,
198
- ) -> tuple[int, np.ndarray[Any, Any]]: # type: ignore[type-arg]
199
  """Synthesize text to speech using Kokoro on Modal GPU.
200
 
201
  Args:
@@ -226,7 +265,7 @@ class ModalTTSExecutor:
226
  "tts_synthesis_complete", sample_rate=result[0], audio_shape=result[1].shape
227
  )
228
 
229
- return result # type: ignore[no-any-return]
230
 
231
  except Exception as e:
232
  logger.error("tts_synthesis_failed", error=str(e), error_type=type(e).__name__)
@@ -237,9 +276,19 @@ class TTSService:
237
  """TTS service wrapper for async usage."""
238
 
239
  def __init__(self) -> None:
240
- """Initialize TTS service."""
241
- if not settings.modal_available:
242
- raise ConfigurationError("Modal credentials required for TTS")
 
 
 
 
 
 
 
 
 
 
243
  self.executor = ModalTTSExecutor()
244
 
245
  async def synthesize_async(
@@ -247,7 +296,7 @@ class TTSService:
247
  text: str,
248
  voice: str = "af_heart",
249
  speed: float = 1.0,
250
- ) -> tuple[int, np.ndarray[Any, Any]] | None: # type: ignore[type-arg]
251
  """Async wrapper for TTS synthesis.
252
 
253
  Args:
 
1
  """Text-to-Speech service using Kokoro 82M via Modal GPU."""
2
 
3
  import asyncio
4
+ import os
5
  from functools import lru_cache
6
  from typing import Any
7
 
8
  import numpy as np
9
  import structlog
10
 
11
+ # Load .env file BEFORE importing Modal SDK
12
+ # Modal SDK reads MODAL_TOKEN_ID and MODAL_TOKEN_SECRET from environment on import
13
+ from dotenv import load_dotenv
14
+ load_dotenv()
15
+
16
  from src.utils.config import settings
17
  from src.utils.exceptions import ConfigurationError
18
 
 
30
  # Modal app and function definitions (module-level for Modal)
31
  _modal_app: Any | None = None
32
  _tts_function: Any | None = None
33
+ _tts_image: Any | None = None
34
 
35
 
36
  def _get_modal_app() -> Any:
37
+ """Get or create Modal app instance.
38
+
39
+ Retrieves Modal credentials directly from environment variables (.env file)
40
+ instead of relying on settings configuration.
41
+ """
42
  global _modal_app
43
  if _modal_app is None:
44
  try:
45
  import modal
46
 
47
+ # Get credentials directly from environment variables
48
+ token_id = os.getenv("MODAL_TOKEN_ID")
49
+ token_secret = os.getenv("MODAL_TOKEN_SECRET")
50
+
51
+ # Validate Modal credentials
52
+ if not token_id or not token_secret:
53
  raise ConfigurationError(
54
+ "Modal credentials not found in environment. "
55
+ "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
56
  )
57
 
58
  # Validate token ID format (Modal token IDs are typically UUIDs or specific formats)
59
+ if len(token_id.strip()) < 10:
60
+ raise ConfigurationError(
61
+ f"Modal token ID appears malformed (too short: {len(token_id)} chars). "
62
+ "Token ID should be a valid Modal token identifier."
63
+ )
64
+
65
+ logger.info(
66
+ "modal_credentials_loaded",
67
+ token_id_prefix=token_id[:8] + "...", # Log prefix for debugging
68
+ has_secret=bool(token_secret),
69
+ )
70
 
71
  try:
72
+ _modal_app = modal.App("deepcritical-tts")
73
  except Exception as e:
74
  error_msg = str(e).lower()
75
  if "token" in error_msg or "malformed" in error_msg or "invalid" in error_msg:
76
  raise ConfigurationError(
77
  f"Modal token validation failed: {e}. "
78
+ "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env are correctly set."
79
  ) from e
80
  raise
81
  except ImportError as e:
 
88
  # Define Modal image with Kokoro dependencies (module-level)
89
  def _get_tts_image() -> Any:
90
  """Get Modal image with Kokoro dependencies."""
91
+ global _tts_image
92
+ if _tts_image is not None:
93
+ return _tts_image
94
+
95
  try:
96
  import modal
97
 
98
+ _tts_image = (
99
  modal.Image.debian_slim(python_version="3.11")
100
  .pip_install(*KOKORO_DEPENDENCIES)
101
  .pip_install("git+https://github.com/hexgrad/kokoro.git")
102
  )
103
+ return _tts_image
104
  except ImportError:
105
  return None
106
 
107
 
108
+ # Modal TTS function - Using serialized=True to allow dynamic creation
109
+ # This will be initialized lazily when _setup_modal_function() is called
110
+ def _create_tts_function() -> Any:
111
+ """Create the Modal TTS function using serialized=True.
112
+
113
+ The serialized=True parameter allows the function to be defined outside
114
+ of global scope, which is necessary for dynamic initialization.
115
+ """
116
+ app = _get_modal_app()
117
+ tts_image = _get_tts_image()
118
+
119
+ if tts_image is None:
120
+ raise ConfigurationError("Modal image setup failed")
121
+
122
+ # Get GPU and timeout from settings (with defaults)
123
+ gpu_type = getattr(settings, "tts_gpu", None) or "T4"
124
+ timeout_seconds = getattr(settings, "tts_timeout", None) or 60
125
+
126
+ @app.function(
127
+ image=tts_image,
128
+ gpu=gpu_type,
129
+ timeout=timeout_seconds,
130
+ serialized=True, # Allow function to be defined outside global scope
131
+ )
132
+ def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]:
133
+ """Modal GPU function for Kokoro TTS.
134
+
135
+ This function runs on Modal's GPU infrastructure.
136
+ Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
137
+ Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py
138
+ """
139
+ import numpy as np
140
+
141
+ # Import Kokoro inside function (lazy load)
142
+ try:
143
+ import torch
144
+ from kokoro import KModel, KPipeline
145
+
146
+ # Initialize model (cached on GPU)
147
+ model = KModel().to("cuda").eval()
148
+ pipeline = KPipeline(lang_code=voice[0])
149
+ pack = pipeline.load_voice(voice)
150
+
151
+ # Generate audio
152
+ for _, ps, _ in pipeline(text, voice, speed):
153
+ ref_s = pack[len(ps) - 1]
154
+ audio = model(ps, ref_s, speed)
155
+ return (24000, audio.numpy())
156
+
157
+ # If no audio generated, return empty
158
+ return (24000, np.zeros(1, dtype=np.float32))
159
+
160
+ except ImportError as e:
161
+ raise ConfigurationError(
162
+ "Kokoro not installed. Install with: pip install git+https://github.com/hexgrad/kokoro.git"
163
+ ) from e
164
+ except Exception as e:
165
+ raise ConfigurationError(f"TTS synthesis failed: {e}") from e
166
+
167
+ return kokoro_tts_function
168
+
169
+
170
  def _setup_modal_function() -> None:
171
  """Setup Modal GPU function for TTS (called once, lazy initialization).
172
 
173
+ Looks up the deployed Modal function instead of creating a new one.
174
+ This requires the 'deepcritical-tts' app to be deployed on Modal.
175
+
176
+ To deploy: modal deploy <script_with_tts_function>.py
177
  """
178
  global _tts_function
179
 
 
181
  return # Already set up
182
 
183
  try:
184
+ import modal
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ # Look up the deployed function from the Modal server
187
+ # This requires the app to be deployed: modal deploy tts_modal.py
188
+ _tts_function = modal.Function.from_name(
189
+ "deepcritical-tts",
190
+ "kokoro_tts_function"
191
+ )
192
 
193
  logger.info(
194
+ "modal_tts_function_lookup_complete",
195
+ app_name="deepcritical-tts",
196
+ function_name="kokoro_tts_function",
 
197
  )
198
 
199
  except Exception as e:
200
  logger.error("modal_tts_function_setup_failed", error=str(e))
201
+ raise ConfigurationError(
202
+ f"Failed to lookup Modal TTS function: {e}. "
203
+ "Make sure the 'deepcritical-tts' app is deployed on Modal."
204
+ ) from e
205
 
206
 
207
  class ModalTTSExecutor:
 
215
  """Initialize Modal TTS executor.
216
 
217
  Note:
218
+ Logs a warning if Modal credentials are not configured in environment.
219
+ Execution will fail at runtime without valid credentials in .env file.
220
  """
221
+ # Check for Modal credentials directly from environment
222
+ token_id = os.getenv("MODAL_TOKEN_ID")
223
+ token_secret = os.getenv("MODAL_TOKEN_SECRET")
224
+
225
+ if not token_id or not token_secret:
226
  logger.warning(
227
+ "Modal credentials not found in environment. "
228
+ "TTS will not be available. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
229
  )
230
 
231
  def synthesize(
 
234
  voice: str = "af_heart",
235
  speed: float = 1.0,
236
  timeout: int = 60,
237
+ ) -> tuple[int, np.ndarray]:
238
  """Synthesize text to speech using Kokoro on Modal GPU.
239
 
240
  Args:
 
265
  "tts_synthesis_complete", sample_rate=result[0], audio_shape=result[1].shape
266
  )
267
 
268
+ return result
269
 
270
  except Exception as e:
271
  logger.error("tts_synthesis_failed", error=str(e), error_type=type(e).__name__)
 
276
  """TTS service wrapper for async usage."""
277
 
278
  def __init__(self) -> None:
279
+ """Initialize TTS service.
280
+
281
+ Validates Modal credentials from environment variables (.env file).
282
+ """
283
+ # Check credentials directly from environment
284
+ token_id = os.getenv("MODAL_TOKEN_ID")
285
+ token_secret = os.getenv("MODAL_TOKEN_SECRET")
286
+
287
+ if not token_id or not token_secret:
288
+ raise ConfigurationError(
289
+ "Modal credentials required for TTS. "
290
+ "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
291
+ )
292
  self.executor = ModalTTSExecutor()
293
 
294
  async def synthesize_async(
 
296
  text: str,
297
  voice: str = "af_heart",
298
  speed: float = 1.0,
299
+ ) -> tuple[int, np.ndarray] | None:
300
  """Async wrapper for TTS synthesis.
301
 
302
  Args:
src/utils/config.py CHANGED
@@ -172,6 +172,10 @@ class Settings(BaseSettings):
172
  le=2.0,
173
  description="TTS speech speed multiplier (0.5x to 2.0x)",
174
  )
 
 
 
 
175
  tts_gpu: str | None = Field(
176
  default=None,
177
  description="Modal GPU type for TTS (T4, A10, A100, L4, L40S). None uses default T4.",
 
172
  le=2.0,
173
  description="TTS speech speed multiplier (0.5x to 2.0x)",
174
  )
175
+ tts_use_llm_polish: bool = Field(
176
+ default=False,
177
+ description="Use LLM for final text polish before TTS (optional, costs API calls)",
178
+ )
179
  tts_gpu: str | None = Field(
180
  default=None,
181
  description="Modal GPU type for TTS (T4, A10, A100, L4, L40S). None uses default T4.",
tests/unit/agents/test_audio_refiner.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for AudioRefiner agent."""
2
+
3
+ import pytest
4
+ from unittest.mock import AsyncMock, Mock, patch
5
+
6
+ from src.agents.audio_refiner import AudioRefiner, refine_text_for_audio
7
+
8
+
9
+ class TestAudioRefiner:
10
+ """Test suite for AudioRefiner functionality."""
11
+
12
+ @pytest.fixture
13
+ def refiner(self):
14
+ """Create AudioRefiner instance."""
15
+ return AudioRefiner()
16
+
17
+ def test_remove_markdown_headers(self, refiner):
18
+ """Test removal of markdown headers."""
19
+ text = """# Main Title
20
+ ## Subtitle
21
+ ### Section
22
+ Content here"""
23
+ result = refiner._remove_markdown_syntax(text)
24
+ assert "#" not in result
25
+ assert "Main Title" in result
26
+ assert "Subtitle" in result
27
+
28
+ def test_remove_bold_italic(self, refiner):
29
+ """Test removal of bold and italic formatting."""
30
+ text = "**Bold text** and *italic text* and __another bold__"
31
+ result = refiner._remove_markdown_syntax(text)
32
+ assert "**" not in result
33
+ assert "*" not in result
34
+ assert "__" not in result
35
+ assert "Bold text" in result
36
+ assert "italic text" in result
37
+
38
+ def test_remove_links(self, refiner):
39
+ """Test removal of markdown links."""
40
+ text = "Check [this link](https://example.com) for details"
41
+ result = refiner._remove_markdown_syntax(text)
42
+ assert "[" not in result
43
+ assert "]" not in result
44
+ assert "https://" not in result
45
+ assert "this link" in result
46
+
47
+ def test_remove_citations_numbered(self, refiner):
48
+ """Test removal of numbered citations."""
49
+ text = "Research shows [1] that metformin [2,3] works [4-6]."
50
+ result = refiner._remove_citations(text)
51
+ assert "[1]" not in result
52
+ assert "[2,3]" not in result
53
+ assert "[4-6]" not in result
54
+ assert "Research shows" in result
55
+
56
+ def test_remove_citations_author_year(self, refiner):
57
+ """Test removal of author-year citations."""
58
+ text = "Studies (Smith et al., 2023) and (Jones, 2022) confirm this."
59
+ result = refiner._remove_citations(text)
60
+ assert "(Smith et al., 2023)" not in result
61
+ assert "(Jones, 2022)" not in result
62
+ assert "Studies" in result
63
+ assert "confirm this" in result
64
+
65
+ def test_remove_first_references_section(self, refiner):
66
+ """Test that References sections are removed while preserving other content."""
67
+ text = """Main content here.
68
+
69
+ # References
70
+ [1] First reference
71
+ [2] Second reference
72
+
73
+ # More Content
74
+ This should remain.
75
+
76
+ ## References
77
+ This second References should also be removed."""
78
+
79
+ result = refiner._remove_references_sections(text)
80
+ assert "Main content here" in result
81
+ assert "References" not in result
82
+ assert "First reference" not in result
83
+ assert "More Content" in result # Content after References should be preserved
84
+ assert "This should remain" in result
85
+ assert "second References should also be removed" not in result # Second References section removed
86
+
87
+ def test_roman_to_int_conversion(self, refiner):
88
+ """Test roman numeral to integer conversion."""
89
+ assert refiner._roman_to_int("I") == 1
90
+ assert refiner._roman_to_int("II") == 2
91
+ assert refiner._roman_to_int("III") == 3
92
+ assert refiner._roman_to_int("IV") == 4
93
+ assert refiner._roman_to_int("V") == 5
94
+ assert refiner._roman_to_int("IX") == 9
95
+ assert refiner._roman_to_int("X") == 10
96
+ assert refiner._roman_to_int("XII") == 12
97
+ assert refiner._roman_to_int("XX") == 20
98
+
99
+ def test_int_to_word_conversion(self, refiner):
100
+ """Test integer to word conversion."""
101
+ assert refiner._int_to_word(1) == "One"
102
+ assert refiner._int_to_word(2) == "Two"
103
+ assert refiner._int_to_word(3) == "Three"
104
+ assert refiner._int_to_word(10) == "Ten"
105
+ assert refiner._int_to_word(20) == "Twenty"
106
+ assert refiner._int_to_word(25) == "25" # Falls back to digit
107
+
108
+ def test_convert_roman_numerals_with_context(self, refiner):
109
+ """Test roman numeral conversion with context words."""
110
+ test_cases = [
111
+ ("Phase I trial", "Phase One trial"),
112
+ ("Phase II study", "Phase Two study"),
113
+ ("Phase III data", "Phase Three data"),
114
+ ("Type I diabetes", "Type One diabetes"),
115
+ ("Type II error", "Type Two error"),
116
+ ("Stage IV cancer", "Stage Four cancer"),
117
+ ("Trial I results", "Trial One results"),
118
+ ]
119
+
120
+ for input_text, expected in test_cases:
121
+ result = refiner._convert_roman_numerals(input_text)
122
+ assert expected in result, f"Failed for: {input_text}"
123
+
124
+ def test_convert_standalone_roman_numerals(self, refiner):
125
+ """Test standalone roman numeral conversion."""
126
+ text = "Results for I, II, and III are positive."
127
+ result = refiner._convert_roman_numerals(text)
128
+ # Standalone roman numerals should be converted
129
+ assert "One" in result or "Two" in result or "Three" in result
130
+
131
+ def test_dont_convert_roman_in_words(self, refiner):
132
+ """Test that roman numerals inside words aren't converted."""
133
+ text = "INVALID data fromIXIN compound"
134
+ result = refiner._convert_roman_numerals(text)
135
+ # Should not break words containing I, V, X, etc.
136
+ assert "INVALID" in result or "Invalid" in result # May be case-normalized
137
+
138
+ def test_clean_special_characters(self, refiner):
139
+ """Test special character cleanup."""
140
+ # Using unicode escapes to avoid syntax issues
141
+ text = "Text with \u2014 em-dash and \u201csmart quotes\u201d and \u2018apostrophes\u2019."
142
+ result = refiner._clean_special_characters(text)
143
+ assert "\u2014" not in result # em-dash
144
+ assert "\u201c" not in result # smart quote open
145
+ assert "\u2018" not in result # smart apostrophe
146
+ assert "-" in result
147
+
148
+ def test_normalize_whitespace(self, refiner):
149
+ """Test whitespace normalization."""
150
+ text = "Text with multiple spaces\n\n\n\nand many newlines"
151
+ result = refiner._normalize_whitespace(text)
152
+ assert " " not in result # No double spaces
153
+ assert "\n\n\n" not in result # Max two newlines
154
+
155
+ async def test_full_refine_workflow(self, refiner):
156
+ """Test complete refinement workflow."""
157
+ markdown_text = """# Summary
158
+
159
+ **Metformin** shows promise for *long COVID* treatment [1].
160
+
161
+ ## Phase I Trials
162
+
163
+ Research (Smith et al., 2023) indicates [2,3]:
164
+ - 50% improvement
165
+ - Low side effects
166
+
167
+ Check [this study](https://example.com) for details.
168
+
169
+ # References
170
+ [1] Smith, J. et al. (2023)
171
+ [2] Jones, K. (2022)
172
+ """
173
+
174
+ result = await refiner.refine_for_audio(markdown_text)
175
+
176
+ # Check markdown removed
177
+ assert "#" not in result
178
+ assert "**" not in result
179
+ assert "*" not in result
180
+
181
+ # Check citations removed
182
+ assert "[1]" not in result
183
+ assert "(Smith et al., 2023)" not in result
184
+
185
+ # Check roman numerals converted
186
+ assert "Phase One" in result
187
+
188
+ # Check references section removed
189
+ assert "References" not in result
190
+ assert "Smith, J. et al." not in result
191
+
192
+ # Check content preserved
193
+ assert "Metformin" in result
194
+ assert "long COVID" in result
195
+
196
+ async def test_convenience_function(self):
197
+ """Test convenience function."""
198
+ text = "**Bold** text with [link](url)"
199
+ result = await refine_text_for_audio(text)
200
+ assert "**" not in result
201
+ assert "[link]" not in result
202
+ assert "Bold" in result
203
+
204
+ async def test_empty_text(self, refiner):
205
+ """Test handling of empty text."""
206
+ assert await refiner.refine_for_audio("") == ""
207
+ assert await refiner.refine_for_audio(" ") == ""
208
+
209
+ async def test_no_references_section(self, refiner):
210
+ """Test text without References section."""
211
+ text = "Main content without references."
212
+ result = await refiner.refine_for_audio(text)
213
+ assert "Main content without references" in result
214
+
215
+ def test_multiple_reference_formats(self, refiner):
216
+ """Test different References section formats."""
217
+ formats = [
218
+ ("# References\nContent", True), # Markdown header - will be removed
219
+ ("## References\nContent", True), # Markdown header - will be removed
220
+ ("**References**\nContent", True), # Bold heading - will be removed
221
+ ("References:\nContent", False), # Standalone without markers - NOT removed (edge case)
222
+ ]
223
+
224
+ for format_text, should_remove in formats:
225
+ text = f"Main content\n{format_text}"
226
+ result = refiner._remove_references_sections(text)
227
+ assert "Main content" in result
228
+ if should_remove:
229
+ assert "References" not in result or result.count("References") == 0
230
+ # Standalone "References:" without markers is an edge case we don't handle
231
+
232
+ def test_preserve_paragraph_structure(self, refiner):
233
+ """Test that paragraph structure is preserved."""
234
+ text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
235
+
236
+ result = refiner._normalize_whitespace(text)
237
+ # Should have paragraph breaks (double newlines)
238
+ assert "\n\n" in result
239
+ # But not excessive newlines
240
+ assert "\n\n\n" not in result
241
+
242
+ @patch('src.agents.audio_refiner.get_pydantic_ai_model')
243
+ async def test_llm_polish_disabled_by_default(self, mock_get_model, refiner):
244
+ """Test that LLM polish is not called by default."""
245
+ text = "Test text"
246
+ result = await refiner.refine_for_audio(text, use_llm_polish=False)
247
+
248
+ # LLM should not be called when disabled
249
+ mock_get_model.assert_not_called()
250
+ assert "Test text" in result
251
+
252
+ @patch('src.agents.audio_refiner.Agent')
253
+ @patch('src.agents.audio_refiner.get_pydantic_ai_model')
254
+ async def test_llm_polish_enabled(self, mock_get_model, mock_agent_class, refiner):
255
+ """Test that LLM polish is called when enabled."""
256
+ # Setup mock
257
+ mock_model = Mock()
258
+ mock_get_model.return_value = mock_model
259
+
260
+ mock_agent_instance = Mock()
261
+ mock_result = Mock()
262
+ mock_result.output = "Polished text"
263
+ mock_agent_instance.run = AsyncMock(return_value=mock_result)
264
+ mock_agent_class.return_value = mock_agent_instance
265
+
266
+ # Test with LLM polish enabled
267
+ text = "**Test** text"
268
+ result = await refiner.refine_for_audio(text, use_llm_polish=True)
269
+
270
+ # Verify LLM was called
271
+ mock_get_model.assert_called_once()
272
+ mock_agent_class.assert_called_once()
273
+ mock_agent_instance.run.assert_called_once()
274
+
275
+ assert result == "Polished text"
276
+
277
+ @patch('src.agents.audio_refiner.Agent')
278
+ @patch('src.agents.audio_refiner.get_pydantic_ai_model')
279
+ async def test_llm_polish_graceful_fallback(self, mock_get_model, mock_agent_class, refiner):
280
+ """Test graceful fallback when LLM polish fails."""
281
+ # Setup mock to raise exception
282
+ mock_get_model.return_value = Mock()
283
+ mock_agent_instance = Mock()
284
+ mock_agent_instance.run = AsyncMock(side_effect=Exception("API Error"))
285
+ mock_agent_class.return_value = mock_agent_instance
286
+
287
+ # Test with LLM polish enabled but failing
288
+ text = "Test text"
289
+ result = await refiner.refine_for_audio(text, use_llm_polish=True)
290
+
291
+ # Should fall back to rule-based output
292
+ assert "Test text" in result
293
+ assert result != "" # Should not be empty
294
+
295
+ async def test_convenience_function_with_llm_polish(self):
296
+ """Test convenience function with LLM polish parameter."""
297
+ with patch.object(AudioRefiner, 'refine_for_audio') as mock_refine:
298
+ mock_refine.return_value = AsyncMock(return_value="Refined text")()
299
+
300
+ # Test without LLM polish
301
+ result = await refine_text_for_audio("Test", use_llm_polish=False)
302
+ mock_refine.assert_called_with("Test", use_llm_polish=False)
303
+
304
+ # Test with LLM polish
305
+ result = await refine_text_for_audio("Test", use_llm_polish=True)
306
+ mock_refine.assert_called_with("Test", use_llm_polish=True)