import os
from typing import Optional

import gradio as gr
import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
from transformers.pipelines.base import Pipeline

load_dotenv()


def ensure_hf_login() -> None:
    """Log in to the Hugging Face Hub if HF_TOKEN is set (needed for gated or private repos)."""
    token = os.getenv("HF_TOKEN")
    if not token:
        print("HF_TOKEN not set; skipping Hugging Face login.")
        return
    try:
        login(token=token)
    except Exception as exc:
        print(f"Failed to log in to the Hugging Face Hub: {exc}")


ensure_hf_login()

LANGUAGE = "Arabic"
BATCH_SIZE = 1
DEVICE = 0 if torch.cuda.is_available() else "cpu"
BASE_MODEL_PATH = "openai/whisper-large-v3-turbo"
LORA_PATH = "anaszil/whisper-large-v3-turbo-darija"

# Lazily initialized ASR pipeline, shared across requests.
PIPELINE: Optional[Pipeline] = None


@spaces.GPU
def _build_pipeline() -> Pipeline:
    """Load the base Whisper model, apply the Darija LoRA adapter, and wrap both in an ASR pipeline."""
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    base_model = WhisperForConditionalGeneration.from_pretrained(
        BASE_MODEL_PATH,
        torch_dtype=torch_dtype,
    )
    model = PeftModel.from_pretrained(base_model, LORA_PATH)
    processor = WhisperProcessor.from_pretrained(
        BASE_MODEL_PATH,
        language=LANGUAGE,
        task="transcribe",
    )
    # Pin generation to Arabic transcription and clear the legacy
    # forced_decoder_ids, which would otherwise conflict with these settings.
    model.generation_config.language = LANGUAGE
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None
    model.eval()
    return pipeline(
        task="automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        chunk_length_s=30,
        device=DEVICE,
    )


def get_pipeline() -> Pipeline:
    """Build the pipeline on first use and cache it for subsequent calls."""
    global PIPELINE
    if PIPELINE is None:
        print("Loading Darija LoRA model...")
        PIPELINE = _build_pipeline()
    return PIPELINE
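
# Optional optimization (a sketch, not part of the original app): the LoRA
# adapter could be merged into the base weights with PeftModel.merge_and_unload(),
# removing the adapter indirection at inference time at the cost of keeping the
# adapter hot-swappable. Inside _build_pipeline() this would look like:
#
#     model = PeftModel.from_pretrained(base_model, LORA_PATH)
#     model = model.merge_and_unload()  # plain WhisperForConditionalGeneration
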
def format_timestamp(
    seconds: Optional[float],
    always_include_hours: bool = False,
    decimal_marker: str = ".",
) -> Optional[str]:
    """Convert a duration in seconds to an [HH:]MM:SS.mmm string; pass None through unchanged."""
    if seconds is None:
        return seconds
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    whole_seconds = milliseconds // 1_000
    milliseconds -= whole_seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{whole_seconds:02d}"
        f"{decimal_marker}{milliseconds:03d}"
    )


@spaces.GPU
def _run_transcription(audio_input, return_timestamps: bool) -> str:
    """Run the ASR pipeline, optionally formatting the output as timestamped segments."""
    asr_pipeline = get_pipeline()
    outputs = asr_pipeline(
        audio_input,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": "transcribe", "language": LANGUAGE},
        return_timestamps=return_timestamps,
    )
    text = outputs["text"]
    if return_timestamps:
        chunks = outputs.get("chunks") or []
        text = "\n".join(
            f"[{format_timestamp(chunk['timestamp'][0])} -> "
            f"{format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in chunks
        )
    return text


def transcribe(audio_input, return_timestamps: bool):
    if audio_input is None:
        return "Please provide audio input either via microphone or file upload."
    return _run_transcription(audio_input, return_timestamps)


def process_audio(audio_input, timestamps):
    """Gradio callback: return (transcription, status message)."""
    if audio_input is None:
        return "Please provide audio input.", "No audio input detected."
    try:
        transcription = transcribe(audio_input, timestamps)
        return transcription, "Transcription completed with the Darija LoRA model."
    except Exception as exc:
        return f"Error: {exc}", f"Transcription failed: {exc}"


with gr.Blocks(title="Darija Speech Transcription") as demo:
    gr.Markdown("# Darija Speech Transcription Demo")
    gr.Markdown("Transcribe Darija audio with the fine-tuned Whisper LoRA model.")

    with gr.Row():
        with gr.Column(scale=1):
            timestamps_checkbox = gr.Checkbox(
                label="Return timestamps",
                value=False,
            )
            audio_component = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
            )
            transcribe_button = gr.Button("Transcribe", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Transcription Output",
                lines=10,
                show_copy_button=True,
            )
            status_message = gr.Markdown("")

    transcribe_button.click(
        fn=process_audio,
        inputs=[audio_component, timestamps_checkbox],
        outputs=[output_text, status_message],
    )

demo.launch()
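
# Programmatic use (illustrative sketch; "sample.wav" is a placeholder path):
# the transcribe() helper can also be called directly, bypassing the UI:
#
#     text = transcribe("sample.wav", return_timestamps=True)
#     # With timestamps enabled, each line looks like:
#     # [00:00.000 -> 00:04.500] <transcribed segment>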