import gradio as gr
import torch
from transformers import (
AutoModelForImageTextToText,
AutoProcessor,
TextIteratorStreamer,
)
from transformers.image_utils import load_image
from threading import Thread
import time
import html
def progress_bar_html(label: str) -> str:
"""
Returns an HTML snippet for a thin progress bar with a label.
The progress bar is styled as a dark animated bar.
"""
return f"""
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
"""
model_name = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
model = AutoModelForImageTextToText.from_pretrained(
model_name, dtype=torch.float32, device_map="auto"
).eval()
processor = AutoProcessor.from_pretrained(model_name)
print(f"Successfully load the model: {model}")
def model_inference(input_dict, history):
text = input_dict["text"]
files = input_dict["files"]
    # load_image handles local paths and URLs; an empty file list yields [].
    images = [load_image(image) for image in files]
if text == "" and not images:
gr.Error("Please input a query and optionally image(s).")
return
if text == "" and images:
gr.Error("Please input a text query along with the image(s).")
return
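    # Build a single-turn message in chat-template format: all images first,
    # then the text prompt. Previous turns are intentionally not included.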
messages = [
{
"role": "user",
"content": [
*[{"type": "image", "image": image} for image in images],
{"type": "text", "text": text},
],
}
]
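    # Apply the chat template and tokenize in one step. BatchFeature.to() only
    # casts floating-point tensors, so the integer input_ids are unaffected.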
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=model.dtype)
streamer = TextIteratorStreamer(
processor, skip_prompt=True, skip_special_tokens=True
)
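    # generate() runs on a worker thread; the streamer feeds tokens back to
    # this generator so partial output can be yielded while generation runs.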
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
# start timer just before generation begins
start_time = time.time()
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
buffer = "Baseline Model Response: "
yield progress_bar_html("Processing...")
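    # Stream decoded chunks as they arrive, escaping HTML so model output is
    # rendered as plain text rather than markup.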
for new_text in streamer:
escaped_new_text = html.escape(new_text)
buffer += escaped_new_text
time.sleep(0.001)
yield buffer
# Ensure generation thread has finished and measure elapsed time
thread.join()
elapsed = time.time() - start_time
elapsed_text = f"\nBaseline Generation Time: {elapsed:.2f} s"
buffer += html.escape(elapsed_text)
yield buffer
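
# Example queries shown in the UI; the image paths are relative to the app
# directory and must exist for the examples to load.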
examples = [
[
{
"text": "Write a descriptive caption for this image in a formal tone.",
"files": ["example_images/example.png"],
}
],
[
{
"text": "What are the characters wearing?",
"files": ["example_images/example.png"],
}
],
]
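
# Multimodal chat UI around the streaming handler; cache_examples=False avoids
# running the model over the examples at startup.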
demo = gr.ChatInterface(
fn=model_inference,
description="# **Smolvlm2-500M-illustration-description** \n (running on CPU) The model only sees the last input, it ignores the previous conversation history.",
examples=examples,
fill_height=True,
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"]),
stop_btn="Stop Generation",
multimodal=True,
cache_examples=False,
)
demo.launch(debug=True)