import asyncio
import builtins
import copy
import datetime
import json
import os
import threading
import traceback
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, cast

import numpy as np
import torch
from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from starlette.websockets import WebSocketDisconnect, WebSocketState

from vibevoice.modular.modeling_vibevoice_streaming_inference import (
    VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_streaming_processor import (
    VibeVoiceStreamingProcessor,
)

BASE = Path(__file__).parent
SAMPLE_RATE = 24_000

def get_timestamp():
    timestamp = datetime.datetime.utcnow().replace(
        tzinfo=datetime.timezone.utc
    ).astimezone(
        datetime.timezone(datetime.timedelta(hours=8))
    ).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return timestamp
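
# Example (illustrative): get_timestamp() returns a UTC+8 wall-clock string with millisecond
# precision, e.g. "2025-01-01 08:00:00.000" for 2025-01-01T00:00:00Z.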

class StreamingTTSService:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        inference_steps: int = 5,
    ) -> None:
        self.model_path = Path(model_path)
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE
        self.processor: Optional[VibeVoiceStreamingProcessor] = None
        self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        # Cache of loaded voice-preset (prefilled prompt) objects, keyed by preset name.
        self._voice_cache: Dict[str, object] = {}
        if device == "mpx":
            print("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"
        if device == "mps" and not torch.backends.mps.is_available():
            print("Warning: MPS not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)

    def load(self) -> None:
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(str(self.model_path))

        # Decide dtype & attention implementation per device
        if self.device == "mps":
            load_dtype = torch.float32
            device_map = None
            attn_impl_primary = "sdpa"
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = "cuda"
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            device_map = "cpu"
            attn_impl_primary = "sdpa"
        print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
        # Load model
        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                str(self.model_path),
                torch_dtype=load_dtype,
                device_map=device_map,
                attn_implementation=attn_impl_primary,
            )
            if self.device == "mps":
                self.model.to("mps")
        except Exception as e:
            if attn_impl_primary == "flash_attention_2":
                print("Error loading the model. Trying SDPA instead. Note that only flash_attention_2 has been fully tested; using SDPA may result in lower audio quality.")
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    str(self.model_path),
                    torch_dtype=load_dtype,
                    device_map=self.device,
                    attn_implementation="sdpa",
                )
                print("Loaded model with SDPA successfully.")
            else:
                raise e
        self.model.eval()
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type="sde-dpmsolver++",
            beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

        self.voice_presets = self._load_voice_presets()
        preset_name = os.environ.get("VOICE_PRESET")
        self.default_voice_key = self._determine_voice_key(preset_name)
        self._ensure_voice_cached(self.default_voice_key)

    def _load_voice_presets(self) -> Dict[str, Path]:
        voices_dir = BASE.parent / "voices" / "streaming_model"
        if not voices_dir.exists():
            raise RuntimeError(f"Voices directory not found: {voices_dir}")
        presets: Dict[str, Path] = {}
        for pt_path in voices_dir.glob("*.pt"):
            presets[pt_path.stem] = pt_path
        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))
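
    # Expected layout (illustrative): <repo root>/voices/streaming_model/<name>.pt,
    # e.g. voices/streaming_model/en-WHTest_man.pt; the stem of each .pt file becomes
    # the preset key offered to clients.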

    def _determine_voice_key(self, name: Optional[str]) -> str:
        if name and name in self.voice_presets:
            return name
        default_key = "en-WHTest_man"
        if default_key in self.voice_presets:
            return default_key
        first_key = next(iter(self.voice_presets))
        print(f"[startup] Using fallback voice preset: {first_key}")
        return first_key

    def _ensure_voice_cached(self, key: str) -> object:
        if key not in self.voice_presets:
            raise RuntimeError(f"Voice preset {key!r} not found")
        if key not in self._voice_cache:
            preset_path = self.voice_presets[key]
            print(f"[startup] Loading voice preset {key} from {preset_path}")
            print(f"[startup] Loading prefilled prompt from {preset_path}")
            prefilled_outputs = torch.load(
                preset_path,
                map_location=self._torch_device,
                weights_only=False,
            )
            self._voice_cache[key] = prefilled_outputs
        return self._voice_cache[key]

    def _get_voice_resources(self, requested_key: Optional[str]) -> Tuple[str, object]:
        key = requested_key if requested_key and requested_key in self.voice_presets else self.default_voice_key
        if key is None:
            key = next(iter(self.voice_presets))
            self.default_voice_key = key
        prefilled_outputs = self._ensure_voice_cached(key)
        return key, prefilled_outputs

    def _prepare_inputs(self, text: str, prefilled_outputs: object):
        if not self.processor or not self.model:
            raise RuntimeError("StreamingTTSService not initialized")
        processor_kwargs = {
            "text": text.strip(),
            "cached_prompt": prefilled_outputs,
            "padding": True,
            "return_tensors": "pt",
            "return_attention_mask": True,
        }
        processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)
        prepared = {
            key: value.to(self._torch_device) if hasattr(value, "to") else value
            for key, value in processed.items()
        }
        return prepared

    def _run_generation(
        self,
        inputs,
        audio_streamer: AudioStreamer,
        errors,
        cfg_scale: float,
        do_sample: bool,
        temperature: float,
        top_p: float,
        refresh_negative: bool,
        prefilled_outputs,
        stop_event: threading.Event,
    ) -> None:
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=self.processor.tokenizer,
                generation_config={
                    "do_sample": do_sample,
                    "temperature": temperature if do_sample else 1.0,
                    "top_p": top_p if do_sample else 1.0,
                },
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=refresh_negative,
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as exc:  # pragma: no cover - diagnostic logging
            errors.append(exc)
            traceback.print_exc()
            # Make sure the stream is terminated even if generation fails.
            audio_streamer.end()

    def stream(
        self,
        text: str,
        cfg_scale: float = 1.5,
        do_sample: bool = False,
        temperature: float = 0.9,
        top_p: float = 0.9,
        refresh_negative: bool = True,
        inference_steps: Optional[int] = None,
        voice_key: Optional[str] = None,
        log_callback: Optional[Callable[..., None]] = None,
        stop_event: Optional[threading.Event] = None,
    ) -> Iterator[np.ndarray]:
        if not text.strip():
            return
        text = text.replace("’", "'")
        selected_voice, prefilled_outputs = self._get_voice_resources(voice_key)

        def emit(event: str, **payload: Any) -> None:
            if log_callback:
                try:
                    log_callback(event, **payload)
                except Exception as exc:
                    print(f"[log_callback] Error while emitting {event}: {exc}")

        steps_to_use = self.inference_steps
        if inference_steps is not None:
            try:
                parsed_steps = int(inference_steps)
                if parsed_steps > 0:
                    steps_to_use = parsed_steps
            except (TypeError, ValueError):
                pass
        if self.model:
            self.model.set_ddpm_inference_steps(num_steps=steps_to_use)
            self.inference_steps = steps_to_use
        inputs = self._prepare_inputs(text, prefilled_outputs)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        errors: list = []
        stop_signal = stop_event or threading.Event()
        thread = threading.Thread(
            target=self._run_generation,
            kwargs={
                "inputs": inputs,
                "audio_streamer": audio_streamer,
                "errors": errors,
                "cfg_scale": cfg_scale,
                "do_sample": do_sample,
                "temperature": temperature,
                "top_p": top_p,
                "refresh_negative": refresh_negative,
                "prefilled_outputs": prefilled_outputs,
                "stop_event": stop_signal,
            },
            daemon=True,
        )
        thread.start()

        generated_samples = 0
        try:
            stream = audio_streamer.get_stream(0)
            for audio_chunk in stream:
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.detach().cpu().to(torch.float32).numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
                if audio_chunk.ndim > 1:
                    audio_chunk = audio_chunk.reshape(-1)
                peak = np.max(np.abs(audio_chunk)) if audio_chunk.size else 0.0
                if peak > 1.0:
                    audio_chunk = audio_chunk / peak
                generated_samples += int(audio_chunk.size)
                emit(
                    "model_progress",
                    generated_sec=generated_samples / self.sample_rate,
                    chunk_sec=audio_chunk.size / self.sample_rate,
                )
                chunk_to_yield = audio_chunk.astype(np.float32, copy=False)
                yield chunk_to_yield
        finally:
            stop_signal.set()
            audio_streamer.end()
            thread.join()
            if errors:
                emit("generation_error", message=str(errors[0]))
                raise errors[0]

    def chunk_to_pcm16(self, chunk: np.ndarray) -> bytes:
        chunk = np.clip(chunk, -1.0, 1.0)
        pcm = (chunk * 32767.0).astype(np.int16)
        return pcm.tobytes()
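
# Illustrative, hypothetical usage of the class above (paths and voice names are placeholders):
#
#     service = StreamingTTSService("/path/to/streaming_model", device="cpu")
#     service.load()
#     chunks = list(service.stream("Hello world", voice_key="en-WHTest_man"))
#     audio = np.concatenate(chunks)             # float32 samples at 24 kHz
#     pcm_bytes = service.chunk_to_pcm16(audio)  # little-endian int16 PCM
#
# For example, np.array([0.0, 0.5, -1.0], dtype=np.float32) maps to the int16 values
# (0, 16383, -32767) after clipping and scaling by 32767.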

app = FastAPI()

async def _startup() -> None:
    model_path = os.environ.get("MODEL_PATH")
    if not model_path:
        raise RuntimeError("MODEL_PATH not set in environment")
    device = os.environ.get("MODEL_DEVICE", "cuda")
    service = StreamingTTSService(
        model_path=model_path,
        device=device,
    )
    service.load()
    app.state.tts_service = service
    app.state.model_path = model_path
    app.state.device = device
    app.state.websocket_lock = asyncio.Lock()
    print("[startup] Model ready.")

def streaming_tts(text: str, **kwargs) -> Iterator[np.ndarray]:
    service: StreamingTTSService = app.state.tts_service
    yield from service.stream(text, **kwargs)

async def websocket_stream(ws: WebSocket) -> None:
    await ws.accept()
    text = ws.query_params.get("text", "")
    print(f"Client connected, text={text!r}")
    cfg_param = ws.query_params.get("cfg")
    steps_param = ws.query_params.get("steps")
    voice_param = ws.query_params.get("voice")
    try:
        cfg_scale = float(cfg_param) if cfg_param is not None else 1.5
    except ValueError:
        cfg_scale = 1.5
    if cfg_scale <= 0:
        cfg_scale = 1.5
    try:
        inference_steps = int(steps_param) if steps_param is not None else None
        if inference_steps is not None and inference_steps <= 0:
            inference_steps = None
    except ValueError:
        inference_steps = None
    service: StreamingTTSService = app.state.tts_service
    lock: asyncio.Lock = app.state.websocket_lock
    if lock.locked():
        busy_message = {
            "type": "log",
            "event": "backend_busy",
            "data": {"message": "Please wait for the other requests to complete."},
            "timestamp": get_timestamp(),
        }
        print("Please wait for the other requests to complete.")
        try:
            await ws.send_text(json.dumps(busy_message))
        except Exception:
            pass
        await ws.close(code=1013, reason="Service busy")
        return
    acquired = False
    try:
        await lock.acquire()
        acquired = True
        log_queue: "Queue[Dict[str, Any]]" = Queue()

        def enqueue_log(event: str, **data: Any) -> None:
            log_queue.put({"event": event, "data": data})

        async def flush_logs() -> None:
            while True:
                try:
                    entry = log_queue.get_nowait()
                except Empty:
                    break
                message = {
                    "type": "log",
                    "event": entry.get("event"),
                    "data": entry.get("data", {}),
                    "timestamp": get_timestamp(),
                }
                try:
                    await ws.send_text(json.dumps(message))
                except Exception:
                    break

        enqueue_log(
            "backend_request_received",
            text_length=len(text or ""),
            cfg_scale=cfg_scale,
            inference_steps=inference_steps,
            voice=voice_param,
        )
        stop_signal = threading.Event()
        iterator = streaming_tts(
            text,
            cfg_scale=cfg_scale,
            inference_steps=inference_steps,
            voice_key=voice_param,
            log_callback=enqueue_log,
            stop_event=stop_signal,
        )
        sentinel = object()
        first_ws_send_logged = False
        await flush_logs()
        try:
            while ws.client_state == WebSocketState.CONNECTED:
                await flush_logs()
                chunk = await asyncio.to_thread(next, iterator, sentinel)
                if chunk is sentinel:
                    break
                chunk = cast(np.ndarray, chunk)
                payload = service.chunk_to_pcm16(chunk)
                await ws.send_bytes(payload)
                if not first_ws_send_logged:
                    first_ws_send_logged = True
                    enqueue_log("backend_first_chunk_sent")
                    await flush_logs()
        except WebSocketDisconnect:
            print("Client disconnected (WebSocketDisconnect)")
            enqueue_log("client_disconnected")
            stop_signal.set()
        finally:
            stop_signal.set()
            enqueue_log("backend_stream_complete")
            await flush_logs()
            try:
                iterator_close = getattr(iterator, "close", None)
                if callable(iterator_close):
                    iterator_close()
            except Exception:
                pass
            # clear the log queue
            while not log_queue.empty():
                try:
                    log_queue.get_nowait()
                except Empty:
                    break
            if ws.client_state == WebSocketState.CONNECTED:
                await ws.close()
            print("WS handler exit")
    finally:
        if acquired:
            lock.release()
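
# Frames sent to the client over this socket are either raw little-endian PCM16 audio bytes
# or JSON log frames of the form (illustrative):
#   {"type": "log", "event": "backend_first_chunk_sent", "data": {...},
#    "timestamp": "2025-01-01 08:00:00.000"}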

def index():
    return FileResponse(BASE / "index.html")

def get_config():
    service: StreamingTTSService = app.state.tts_service
    voices = sorted(service.voice_presets.keys())
    return {
        "voices": voices,
        "default_voice": service.default_voice_key,
    }
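
# The handlers above (_startup, websocket_stream, index, get_config) are not registered with
# FastAPI in this listing; the decorators or registration calls are presumably elsewhere.
# A minimal sketch of one possible wiring, with hypothetical route paths, could be:
#
#     app.add_event_handler("startup", _startup)
#     app.add_api_websocket_route("/stream", websocket_stream)
#     app.add_api_route("/", index)
#     app.add_api_route("/config", get_config)
#
# after which the app could be served with, e.g., `uvicorn app:app --host 0.0.0.0 --port 7860`
# (module name and port are placeholders).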