Spaces: Running on Zero
File size: 4,921 Bytes
22379f7
import os
from typing import Optional
import gradio as gr
import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
from transformers.pipelines.base import Pipeline
load_dotenv()

def ensure_hf_login() -> None:
    token = os.getenv("HF_TOKEN")
    if not token:
        print("HF_TOKEN not set; skipping Hugging Face login.")
        return
    try:
        login(token=token)
    except Exception as exc:
        print(f"Failed to login to Hugging Face Hub: {exc}")


ensure_hf_login()

LANGUAGE = "Arabic"
BATCH_SIZE = 1
DEVICE = 0 if torch.cuda.is_available() else "cpu"
BASE_MODEL_PATH = "openai/whisper-large-v3-turbo"
LORA_PATH = "anaszil/whisper-large-v3-turbo-darija"
PIPELINE: Optional[Pipeline] = None

@spaces.GPU
def _build_pipeline() -> Pipeline:
    """Load the base Whisper model, apply the Darija LoRA adapter, and wrap both in an ASR pipeline."""
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    base_model = WhisperForConditionalGeneration.from_pretrained(
        BASE_MODEL_PATH,
        torch_dtype=torch_dtype,
    )
    model = PeftModel.from_pretrained(base_model, LORA_PATH)
    processor = WhisperProcessor.from_pretrained(
        BASE_MODEL_PATH,
        language=LANGUAGE,
        task="transcribe",
    )
    # Pin the generation config to Arabic transcription and drop any forced decoder ids.
    model.generation_config.language = LANGUAGE
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None
    model.eval()
    return pipeline(
        task="automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        chunk_length_s=30,
        device=DEVICE,
    )


def get_pipeline() -> Pipeline:
    # Build the pipeline lazily on first use and cache it in the module-level PIPELINE global.
    global PIPELINE
    if PIPELINE is None:
        print("Loading Darija LoRA model...")
        PIPELINE = _build_pipeline()
    return PIPELINE


def format_timestamp(
    seconds: Optional[float],
    always_include_hours: bool = False,
    decimal_marker: str = ".",
) -> Optional[str]:
    """Format a duration in seconds as [hh:]mm:ss.mmm, or return None if seconds is None."""
    if seconds is None:
        return seconds
    milliseconds = round(seconds * 1000.0)
    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000
    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000
    whole_seconds = milliseconds // 1_000
    milliseconds -= whole_seconds * 1_000
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{whole_seconds:02d}"
        f"{decimal_marker}{milliseconds:03d}"
    )


@spaces.GPU
def _run_transcription(audio_input, return_timestamps: bool) -> str:
    asr_pipeline = get_pipeline()
    outputs = asr_pipeline(
        audio_input,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": "transcribe", "language": LANGUAGE},
        return_timestamps=return_timestamps,
    )
    text = outputs["text"]
    if return_timestamps:
        # Render each chunk as "[start -> end] text", one chunk per line.
        chunks = outputs.get("chunks") or []
        text = "\n".join(
            f"[{format_timestamp(chunk['timestamp'][0])} -> "
            f"{format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in chunks
        )
    return text


def transcribe(audio_input, return_timestamps: bool):
    if audio_input is None:
        return "Please provide audio input either via microphone or file upload."
    return _run_transcription(audio_input, return_timestamps)


def process_audio(audio_input, timestamps):
    if audio_input is None:
        return "Please provide audio input.", "No audio input detected."
    try:
        transcription = transcribe(audio_input, timestamps)
        return transcription, "Transcription completed with the Darija LoRA model."
    except Exception as exc:
        return f"Error: {exc}", f"Transcription failed: {exc}"


with gr.Blocks(title="Darija Speech Transcription") as demo:
    gr.Markdown("# Darija Speech Transcription Demo")
    gr.Markdown("Transcribe Darija audio with the fine-tuned Whisper LoRA model.")
    with gr.Row():
        with gr.Column(scale=1):
            timestamps_checkbox = gr.Checkbox(
                label="Return timestamps",
                value=False,
            )
            audio_component = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
            )
            transcribe_button = gr.Button("Transcribe", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Transcription Output",
                lines=10,
                show_copy_button=True,
            )
            status_message = gr.Markdown("")
    transcribe_button.click(
        fn=process_audio,
        inputs=[audio_component, timestamps_checkbox],
        outputs=[output_text, status_message],
    )

demo.launch()
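
# Illustrative usage sketch (kept as comments so nothing extra runs on the Space):
# the helpers above can also be called directly, e.g. from a Python shell with this
# module imported. "sample.wav" is a hypothetical local file path, not an asset
# that ships with this repo.
#
#     plain = transcribe("sample.wav", return_timestamps=False)
#     timed = transcribe("sample.wav", return_timestamps=True)
#     print(plain)
#     print(timed)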