import spaces
import torch
import gradio as gr
import tempfile
import os
import uuid
import scipy.io.wavfile
import time
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
# import subprocess
# subprocess.run(
#     "pip install flash-attn --no-build-isolation",
#     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
#     shell=True,
# )
# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16
# MODEL_NAME = "openai/whisper-large-v3-turbo"
# model = AutoModelForSpeechSeq2Seq.from_pretrained(
#     MODEL_NAME, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
# )
# model.to(device)
# processor = AutoProcessor.from_pretrained(MODEL_NAME)
# tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME)
# pipe = pipeline(
#     task="automatic-speech-recognition",
#     model=model,
#     tokenizer=tokenizer,
#     feature_extractor=processor.feature_extractor,
#     chunk_length_s=10,
#     torch_dtype=torch_dtype,
#     device=device,
# )
from omegaconf import OmegaConf

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
default_model_id = "whisper-large-v3"


def load_pipe(model_id: str):
    return pipeline(
        "automatic-speech-recognition",
        model=model_id,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=8,
        torch_dtype=torch_dtype,
        device=device,
    )


OmegaConf.register_new_resolver("load_pipe", load_pipe)
models_config = OmegaConf.to_object(OmegaConf.load("models.yaml"))
model = models_config[default_model_id]["model"]
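# models.yaml is resolved with the "load_pipe" resolver registered above, so each
# entry can construct its ASR pipeline via OmegaConf interpolation. A minimal sketch
# of the expected layout (the repo id and dialect entries below are hypothetical
# placeholders, not the actual values shipped with this Space):
#
#   whisper-large-v3:
#     model: ${load_pipe:some-org/whisper-large-v3-hakka}
#     dialect_mapping:
#       四縣: sixian
#       海陸: hailu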
def automatic_speech_recognition(model_id: str, dialect_id: str, audio_file: str):
    # model_id is currently unused; the module-level `model` pipeline is used instead.
    generate_kwargs = {
        "task": "transcribe",
        "language": "Chinese",
        "num_beams": 1,
        "prompt_ids": torch.from_numpy(model.tokenizer.get_prompt_ids(dialect_id)).to(
            device
        ),
    }
    # The dialect id is passed as a Whisper prompt, then stripped from the returned text.
    return model(audio_file, generate_kwargs=generate_kwargs)["text"].replace(
        f" {dialect_id}", ""
    )
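# Example call (hypothetical dialect id and audio path; valid dialect ids come
# from the dialect_mapping in models.yaml):
#   text = automatic_speech_recognition(default_model_id, "sixian", "sample.wav")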
def stream_transcribe(stream, new_chunk, dialect_id):
    start_time = time.time()
    try:
        sr, y = new_chunk
        # Convert to mono if stereo
        if y.ndim > 1:
            y = y.mean(axis=1)
        y = y.astype(np.float32)
        # Normalize the chunk; skip silent chunks to avoid dividing by zero
        peak = np.max(np.abs(y))
        if peak > 0:
            y /= peak
        # Accumulate audio across streaming callbacks so the model sees the whole utterance
        if stream is not None:
            stream = np.concatenate([stream, y])
        else:
            stream = y
        generate_kwargs = {
            "task": "transcribe",
            "language": "Chinese",
            "num_beams": 1,
            "prompt_ids": torch.from_numpy(
                model.tokenizer.get_prompt_ids(dialect_id)
            ).to(device),
        }
        # transcription = pipe({"sampling_rate": sr, "raw": stream})["text"]
        # Pass the sampling rate so the pipeline resamples the microphone audio correctly.
        transcription = model(
            {"sampling_rate": sr, "raw": stream}, generate_kwargs=generate_kwargs
        )["text"]
        end_time = time.time()
        latency = end_time - start_time
        return stream, transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during transcription: {e}")
        return stream, str(e), "Error"
def clear():
    return ""


def clear_state():
    return None
with gr.Blocks() as microphone:
    with gr.Column():
        gr.Markdown(
            "# Realtime Hakka ASR\nNote: the first token takes about 5 seconds; after that it runs smoothly."
        )
        with gr.Row():
            input_audio_microphone = gr.Audio(streaming=True)
            dialect_drop_down = gr.Dropdown(
                choices=[
                    (k, v)
                    for k, v in models_config[default_model_id]["dialect_mapping"].items()
                ],
                value=list(models_config[default_model_id]["dialect_mapping"].values())[0],
                label="腔調",  # "dialect/accent"
            )
        with gr.Row():
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            clear_button = gr.Button("Clear Output")
        state = gr.State()
        input_audio_microphone.stream(
            stream_transcribe,
            [state, input_audio_microphone, dialect_drop_down],
            [state, output, latency_textbox],
            time_limit=30, stream_every=2, concurrency_limit=None,
        )
        clear_button.click(clear_state, outputs=[state]).then(clear, outputs=[output])
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.TabbedInterface([microphone], ["Microphone"])

demo.launch()