import spaces
import torch
import gradio as gr
import tempfile
import os
import uuid
import scipy.io.wavfile
import time
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
# import subprocess
# subprocess.run(
#     "pip install flash-attn --no-build-isolation",
#     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
#     shell=True,
# )
# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16
# MODEL_NAME = "openai/whisper-large-v3-turbo"
# model = AutoModelForSpeechSeq2Seq.from_pretrained(
#     MODEL_NAME, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
# )
# model.to(device)
# processor = AutoProcessor.from_pretrained(MODEL_NAME)
# tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME)
# pipe = pipeline(
#     task="automatic-speech-recognition",
#     model=model,
#     tokenizer=tokenizer,
#     feature_extractor=processor.feature_extractor,
#     chunk_length_s=10,
#     torch_dtype=torch_dtype,
#     device=device,
# )
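# The commented-out block above wires up a single openai/whisper-large-v3-turbo
# pipeline directly; the active code below builds pipelines from models.yaml instead.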
from omegaconf import OmegaConf
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
default_model_id = "whisper-large-v3"
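# load_pipe builds a Hugging Face ASR pipeline for a given model id using the
# device/dtype selected above; it is registered below as an OmegaConf resolver so
# models.yaml can construct pipelines when the config is loaded.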
def load_pipe(model_id: str):
    return pipeline(
        "automatic-speech-recognition",
        model=model_id,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=8,
        torch_dtype=torch_dtype,
        device=device,
    )
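# Register load_pipe as an OmegaConf resolver, then load models.yaml. That file is
# not shown here, but from the lookups below it is expected to look roughly like:
#   whisper-large-v3:
#     model: ${load_pipe:<hub-model-id>}
#     dialect_mapping: {<display name>: <dialect token>, ...}
# (structure inferred for illustration; the real ids live in models.yaml)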
OmegaConf.register_new_resolver("load_pipe", load_pipe)
models_config = OmegaConf.to_object(OmegaConf.load("models.yaml"))
model = models_config[default_model_id]["model"]
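# automatic_speech_recognition transcribes a complete audio file with the default
# pipeline, passing the dialect id as a Whisper prompt and stripping the echoed
# dialect tag from the returned text.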
def automatic_speech_recognition(model_id: str, dialect_id: str, audio_file: str):
    generate_kwargs = {
        "task": "transcribe",
        "language": "Chinese",
        "num_beams": 1,
        "prompt_ids": torch.from_numpy(model.tokenizer.get_prompt_ids(dialect_id)).to(
            device
        ),
    }
    return model(audio_file, generate_kwargs=generate_kwargs)["text"].replace(
        f" {dialect_id}", ""
    )
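# Streaming handler: each microphone chunk is normalized, appended to the audio
# buffer held in the Gradio session state, and the whole buffer is re-transcribed
# with the selected dialect as a Whisper prompt; latency is reported per call.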
@spaces.GPU
def stream_transcribe(stream, new_chunk, dialect_id):
    start_time = time.time()
    try:
        sr, y = new_chunk
        # Convert to mono if stereo
        if y.ndim > 1:
            y = y.mean(axis=1)
        y = y.astype(np.float32)
        peak = np.max(np.abs(y))
        if peak > 0:  # avoid division by zero on silent chunks
            y /= peak
        if stream is not None:
            stream = np.concatenate([stream, y])
        else:
            stream = y
        generate_kwargs = {
            "task": "transcribe",
            "language": "Chinese",
            "num_beams": 1,
            "prompt_ids": torch.from_numpy(model.tokenizer.get_prompt_ids(dialect_id)).to(
                device
            ),
        }
        # transcription = pipe({"sampling_rate": sr, "raw": stream})["text"]
        transcription = model(stream, generate_kwargs=generate_kwargs)["text"]
        end_time = time.time()
        latency = end_time - start_time
        return stream, transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during transcription: {e}")
        return stream, str(e), "Error"
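# Helpers to reset the transcription textbox and the accumulated audio state.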
def clear():
    return ""

def clear_state():
    return None
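# Microphone tab: streaming audio input, dialect selector, transcription output,
# latency readout, and a button that clears both the text and the audio state.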
with gr.Blocks() as microphone:
    with gr.Column():
        gr.Markdown(
            "# Realtime Hakka ASR\nNote: the first token takes about 5 seconds; after that, transcription streams smoothly."
        )
        with gr.Row():
            input_audio_microphone = gr.Audio(streaming=True)
            dialect_drop_down = gr.Dropdown(
                choices=[
                    (k, v)
                    for k, v in models_config[default_model_id]["dialect_mapping"].items()
                ],
                value=list(models_config[default_model_id]["dialect_mapping"].values())[0],
                label="Dialect",
            )
        with gr.Row():
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            clear_button = gr.Button("Clear Output")
        state = gr.State()
        input_audio_microphone.stream(
            stream_transcribe,
            [state, input_audio_microphone, dialect_drop_down],
            [state, output, latency_textbox],
            time_limit=30,
            stream_every=2,
            concurrency_limit=None,
        )
        clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
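# Wrap the microphone tab in a tabbed interface and launch the app.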
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.TabbedInterface([microphone], ["Microphone"])

demo.launch()