# Hugging Face Space, running on ZeroGPU hardware.
import os
from typing import Optional

import gradio as gr
import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
from transformers.pipelines.base import Pipeline

# Read HF_TOKEN and other settings from a local .env file, if present.
load_dotenv()


def ensure_hf_login() -> None:
    """Log in to the Hugging Face Hub when HF_TOKEN is set; otherwise skip."""
    token = os.getenv("HF_TOKEN")
    if not token:
        print("HF_TOKEN not set; skipping Hugging Face login.")
        return
    try:
        login(token=token)
    except Exception as exc:
        print(f"Failed to log in to the Hugging Face Hub: {exc}")


ensure_hf_login()

LANGUAGE = "Arabic"
BATCH_SIZE = 1
DEVICE = 0 if torch.cuda.is_available() else "cpu"
BASE_MODEL_PATH = "openai/whisper-large-v3-turbo"
LORA_PATH = "anaszil/whisper-large-v3-turbo-darija"

# Lazily built ASR pipeline, cached at module level after the first request.
PIPELINE: Optional[Pipeline] = None


def _build_pipeline() -> Pipeline:
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    base_model = WhisperForConditionalGeneration.from_pretrained(
        BASE_MODEL_PATH,
        torch_dtype=torch_dtype,
    )
    # Attach the Darija LoRA adapter on top of the base Whisper weights.
    model = PeftModel.from_pretrained(base_model, LORA_PATH)
    processor = WhisperProcessor.from_pretrained(
        BASE_MODEL_PATH,
        language=LANGUAGE,
        task="transcribe",
    )
    # Pin the decoding language and task instead of relying on forced decoder IDs.
    model.generation_config.language = LANGUAGE
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None
    model.eval()
    return pipeline(
        task="automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        chunk_length_s=30,
        device=DEVICE,
    )
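
# Note: the LoRA adapter stays separate at inference time; as an alternative,
# peft's merge_and_unload() can fold the adapter weights into the base model
# for slightly faster inference.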


def get_pipeline() -> Pipeline:
    """Return the shared ASR pipeline, building it on first use."""
    global PIPELINE
    if PIPELINE is None:
        print("Loading Darija LoRA model...")
        PIPELINE = _build_pipeline()
    return PIPELINE


def format_timestamp(
    seconds: Optional[float],
    always_include_hours: bool = False,
    decimal_marker: str = ".",
) -> Optional[str]:
    """Render seconds as [HH:]MM:SS.mmm; None passes through unchanged."""
    if seconds is None:
        return seconds
    milliseconds = round(seconds * 1000.0)
    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000
    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000
    whole_seconds = milliseconds // 1_000
    milliseconds -= whole_seconds * 1_000
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{whole_seconds:02d}"
        f"{decimal_marker}{milliseconds:03d}"
    )
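
# Example: format_timestamp(3723.456) returns "01:02:03.456", and
# format_timestamp(63.5, always_include_hours=True) returns "00:01:03.500".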


# On ZeroGPU Spaces, GPU work must run inside a @spaces.GPU-decorated function;
# presumably that is why `spaces` is imported above.
@spaces.GPU
def _run_transcription(audio_input, return_timestamps: bool) -> str:
    asr_pipeline = get_pipeline()
    outputs = asr_pipeline(
        audio_input,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": "transcribe", "language": LANGUAGE},
        return_timestamps=return_timestamps,
    )
    text = outputs["text"]
    if return_timestamps:
        chunks = outputs.get("chunks") or []
        text = "\n".join(
            f"[{format_timestamp(chunk['timestamp'][0])} -> "
            f"{format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in chunks
        )
    return text
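
# With return_timestamps=True, each line of the result looks like
#   [00:00.000 -> 00:04.500] <segment text>
# (times rendered by format_timestamp above).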


def transcribe(audio_input, return_timestamps: bool):
    if audio_input is None:
        return "Please provide audio input either via microphone or file upload."
    return _run_transcription(audio_input, return_timestamps)


def process_audio(audio_input, timestamps):
    """Gradio handler: return (transcription, status message)."""
    if audio_input is None:
        return "Please provide audio input.", "No audio input detected."
    try:
        transcription = transcribe(audio_input, timestamps)
        return transcription, "Transcription completed with the Darija LoRA model."
    except Exception as exc:
        return f"Error: {exc}", f"Transcription failed: {exc}"


with gr.Blocks(title="Darija Speech Transcription") as demo:
    gr.Markdown("# Darija Speech Transcription Demo")
    gr.Markdown("Transcribe Darija audio with the fine-tuned Whisper LoRA model.")
    with gr.Row():
        with gr.Column(scale=1):
            timestamps_checkbox = gr.Checkbox(
                label="Return timestamps",
                value=False,
            )
            audio_component = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
            )
            transcribe_button = gr.Button("Transcribe", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Transcription Output",
                lines=10,
                show_copy_button=True,
            )
            status_message = gr.Markdown("")
    # Event wiring must stay inside the Blocks context.
    transcribe_button.click(
        fn=process_audio,
        inputs=[audio_component, timestamps_checkbox],
        outputs=[output_text, status_message],
    )

demo.launch()
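
# To run locally (a sketch; the file name and unpinned dependencies are assumptions):
#   pip install gradio spaces torch transformers peft huggingface_hub python-dotenv
#   python app.py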