# rahul7star's picture
# Update app1.py
# fded6f9 verified
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import upload_file
import os
import uuid
import logging
# --- Model configuration ---------------------------------------------------
# Hugging Face Hub id of the vision-language model used for captioning.
MID = "apple/FastVLM-0.5B"
# Sentinel id spliced into the prompt token sequence at the "<image>" position
# (see caption_image). It is negative, so it cannot collide with a real token id.
IMAGE_TOKEN_INDEX = -200
# Hub repo (uploaded with repo_type="model") that receives images and captions.
HF_MODEL = "rahul7star/ImageExplain"

# Tokenizer and model start unloaded; load_model() fills and caches them.
tok = model = None
def load_model():
    """Return the ``(tokenizer, model)`` pair, loading both on first use.

    The pair is cached in the module globals ``tok`` and ``model``; once both
    are set, later calls return them without reloading. When CUDA is available
    the model is loaded in float16 on "cuda", otherwise in float32 on "cpu".
    """
    global tok, model
    # Fast path: everything is already loaded.
    if tok is not None and model is not None:
        return tok, model

    print("Loading model...")
    tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)

    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    # Half precision only on GPU; CPU inference stays in full precision.
    dtype = torch.float16 if on_gpu else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        MID,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
    )
    print(f"Model loaded on {device.upper()} successfully!")
    return tok, model
import os
import uuid
import logging
from datetime import datetime
import tempfile
from huggingface_hub import HfApi, upload_file
def upload_to_hf(video_path, summary_text):
    """Upload a local file plus its generated text to the Hub repo ``HF_MODEL``.

    Both files go into a fresh folder
    ``<YYYY-MM-DD>-APPLE-IMAGE_FOLDER/upload_<8 hex chars>``, so separate
    uploads never overwrite each other.

    Args:
        video_path: Local path of the file to upload. Despite the name, this
            app passes the captioned image here.
        summary_text: Text stored next to it as ``summary_<file stem>.txt``.

    Returns:
        The repo-relative folder the two files were uploaded to.

    Raises:
        Whatever ``upload_file`` raises (auth, network, missing file, ...);
        the caller is responsible for handling it.
    """
    # Token from the Space secret; if unset, upload_file receives None and
    # falls back to its own default token resolution.
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

    date_folder = f"{datetime.now().strftime('%Y-%m-%d')}-APPLE-IMAGE_FOLDER"
    # Random subfolder per call, so concurrent uploads land in distinct folders.
    hf_folder = f"{date_folder}/upload_{uuid.uuid4().hex[:8]}"
    logging.info("Uploading files to HF folder: %s in repo %s", hf_folder, HF_MODEL)

    # --- Upload the media file ---
    video_filename = os.path.basename(video_path)
    video_hf_path = f"{hf_folder}/{video_filename}"
    upload_file(
        path_or_fileobj=video_path,
        path_in_repo=video_hf_path,
        repo_id=HF_MODEL,
        repo_type="model",
        token=token,
    )
    logging.info("✅ Uploaded video to HF: %s", video_hf_path)

    # --- Upload the summary text ---
    # Sent straight from memory as UTF-8 bytes. The old version wrote to a
    # fixed-name file in the temp dir, which was never deleted and could be
    # overwritten by a concurrent request with the same basename.
    summary_filename = f"summary_{os.path.splitext(video_filename)[0]}.txt"
    summary_hf_path = f"{hf_folder}/{summary_filename}"
    upload_file(
        path_or_fileobj=summary_text.encode("utf-8"),
        path_in_repo=summary_hf_path,
        repo_id=HF_MODEL,
        repo_type="model",
        token=token,
    )
    logging.info("✅ Uploaded summary to HF: %s", summary_hf_path)
    return hf_folder
def caption_image(image, custom_prompt=None):
    """Caption ``image`` with FastVLM, then upload the image and caption to HF.

    Args:
        image: PIL image from the Gradio component, or None if nothing was
            uploaded.
        custom_prompt: Optional instruction. If None, empty or whitespace-only,
            "Describe this image in detail." is used.

    Returns:
        Text for the output box, which is one of:
        - the caption, a ``---`` separator and the HF upload folder;
        - the caption plus an upload error, if only the upload failed;
        - an error message, if captioning itself failed.
    """
    if image is None:
        return "Please upload an image first."

    # --- Caption generation ---
    try:
        tok, model = load_model()

        if image.mode != "RGB":
            image = image.convert("RGB")

        prompt = (custom_prompt or "").strip() or "Describe this image in detail."
        messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
        rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        # Tokenize the text on each side of "<image>" separately, then splice
        # the IMAGE_TOKEN_INDEX sentinel in between.
        pre, post = rendered.split("<image>", 1)
        pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
        input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)

        px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
        px = px.to(model.device, dtype=model.dtype)

        with torch.no_grad():
            out = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=128,
                do_sample=False,
            )

        # Decode only the continuation. If generate() echoed the prompt (the
        # output starts with exactly input_ids), drop that prefix; otherwise
        # the output is already the continuation. This replaces splitting on
        # the word "assistant", which truncated any caption containing it.
        gen_ids = out[0]
        prompt_len = input_ids.shape[1]
        if gen_ids.shape[0] >= prompt_len and torch.equal(
            gen_ids[:prompt_len], input_ids[0].to(gen_ids)
        ):
            gen_ids = gen_ids[prompt_len:]
        response = tok.decode(gen_ids, skip_special_tokens=True).strip()
    except Exception as e:
        logging.exception("Caption generation failed")
        return f"Error generating caption: {str(e)}"

    # --- Upload image + caption ---
    # A per-request temp directory avoids races on a shared /tmp path. The
    # file keeps its name, so the path in the repo is unchanged. The caption
    # is still returned if the upload fails.
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_img = os.path.join(tmp_dir, "uploaded_image.png")
            image.save(temp_img)
            upload_status = upload_to_hf(temp_img, response)
    except Exception as e:
        logging.exception("Upload to HF failed")
        return f"{response}\n\n---\nUpload failed: {e}"

    return f"{response}\n\n---\n{upload_status}"
# Gradio UI: image + optional prompt in, caption + upload status out.
with gr.Blocks(title="FastVLM Image Captioning") as demo:
    gr.Markdown("# 🖼️ FastVLM Image Captioning")
    # Link to the companion video-analysis Space.
    gr.Markdown(
        "### 🔗 For **Video Analysis**, click here: "
        "[Video-Analysis-AppleFastVLM-7B](https://huggingface.co/spaces/rahul7star/Video-Analysis-AppleFastVLM-7B)"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            custom_prompt = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="Leave empty for default prompt",
                lines=2,
            )
            generate_btn = gr.Button("Generate + Upload", variant="primary")
            clear_btn = gr.ClearButton([image_input, custom_prompt])
        with gr.Column():
            output = gr.Textbox(label="Generated Caption + Upload Status", lines=8, show_copy_button=True)

    # The output box is created after the Clear button, so it is registered
    # here. Without this, Clear left the previous caption visible.
    clear_btn.add(output)
    generate_btn.click(caption_image, [image_input, custom_prompt], output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)