# whisper-darija / app_2.py

import os
from typing import Optional
import gradio as gr
import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
from transformers.pipelines.base import Pipeline
load_dotenv()
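
# Log in to the Hugging Face Hub with HF_TOKEN (read from the environment or a local .env file), if it is set.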
def ensure_hf_login() -> None:
    token = os.getenv("HF_TOKEN")
    if not token:
        print("HF_TOKEN not set; skipping Hugging Face login.")
        return
    try:
        login(token=token)
    except Exception as exc:
        print(f"Failed to log in to the Hugging Face Hub: {exc}")

ensure_hf_login()
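
# Decoding language, batch size, device, the base Whisper checkpoint, and the Darija LoRA adapter repo.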
LANGUAGE = "Arabic"
BATCH_SIZE = 1
DEVICE = 0 if torch.cuda.is_available() else "cpu"
BASE_MODEL_PATH = "openai/whisper-large-v3-turbo"
LORA_PATH = "anaszil/whisper-large-v3-turbo-darija"
PIPELINE: Optional[Pipeline] = None
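
# Build the ASR pipeline (base Whisper + Darija LoRA adapter); @spaces.GPU requests GPU time on ZeroGPU Spaces.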
@spaces.GPU
def _build_pipeline() -> Pipeline:
    # Half precision on GPU, full precision on CPU.
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    base_model = WhisperForConditionalGeneration.from_pretrained(
        BASE_MODEL_PATH,
        torch_dtype=torch_dtype,
    )
    # Attach the Darija LoRA adapter to the base Whisper model.
    model = PeftModel.from_pretrained(base_model, LORA_PATH)
    processor = WhisperProcessor.from_pretrained(
        BASE_MODEL_PATH,
        language=LANGUAGE,
        task="transcribe",
    )
    # Pin decoding to Arabic transcription and drop the legacy forced decoder ids.
    model.generation_config.language = LANGUAGE
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None
    model.eval()
    return pipeline(
        task="automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        chunk_length_s=30,
        device=DEVICE,
    )
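
# Lazily build the pipeline on first use and cache it in the module-level PIPELINE global.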
def get_pipeline() -> Pipeline:
    global PIPELINE
    if PIPELINE is None:
        print("Loading Darija LoRA model...")
        PIPELINE = _build_pipeline()
    return PIPELINE
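
# Render a timestamp in seconds as [HH:]MM:SS.mmm (hours shown only when present, unless forced).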
def format_timestamp(
    seconds: Optional[float],
    always_include_hours: bool = False,
    decimal_marker: str = ".",
) -> Optional[str]:
    if seconds is None:
        return seconds
    milliseconds = round(seconds * 1000.0)
    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000
    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000
    whole_seconds = milliseconds // 1_000
    milliseconds -= whole_seconds * 1_000
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{whole_seconds:02d}"
        f"{decimal_marker}{milliseconds:03d}"
    )
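
# Run the cached pipeline on one audio input; with timestamps enabled, format one line per chunk.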
@spaces.GPU
def _run_transcription(audio_input, return_timestamps: bool) -> str:
    asr_pipeline = get_pipeline()
    outputs = asr_pipeline(
        audio_input,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": "transcribe", "language": LANGUAGE},
        return_timestamps=return_timestamps,
    )
    text = outputs["text"]
    if return_timestamps:
        chunks = outputs.get("chunks") or []
        text = "\n".join(
            f"[{format_timestamp(chunk['timestamp'][0])} -> "
            f"{format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in chunks
        )
    return text
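
# Validate the audio input before handing it to the GPU transcription function.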
def transcribe(audio_input, return_timestamps: bool):
    if audio_input is None:
        return "Please provide audio input either via microphone or file upload."
    return _run_transcription(audio_input, return_timestamps)
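
# Gradio click handler: returns the transcription text plus a short status message.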
def process_audio(audio_input, timestamps):
    if audio_input is None:
        return "Please provide audio input.", "No audio input detected."
    try:
        transcription = transcribe(audio_input, timestamps)
        return transcription, "Transcription completed with the Darija LoRA model."
    except Exception as exc:
        return f"Error: {exc}", f"Transcription failed: {exc}"
with gr.Blocks(title="Darija Speech Transcription") as demo:
    gr.Markdown("# Darija Speech Transcription Demo")
    gr.Markdown("Transcribe Darija audio with the fine-tuned Whisper LoRA model.")
    with gr.Row():
        with gr.Column(scale=1):
            timestamps_checkbox = gr.Checkbox(
                label="Return timestamps",
                value=False,
            )
            audio_component = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
            )
            transcribe_button = gr.Button("Transcribe", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Transcription Output",
                lines=10,
                show_copy_button=True,
            )
            status_message = gr.Markdown("")

    transcribe_button.click(
        fn=process_audio,
        inputs=[audio_component, timestamps_checkbox],
        outputs=[output_text, status_message],
    )

demo.launch()