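"""Gradio app that converts document-page images into structured Markdown,
HTML, or raw DocTags using the IBM granite-docling-258M vision-language
model, with optional Flash Attention 2 and a CPU fallback."""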
import spaces
import torch
import gradio as gr
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import AutoProcessor, AutoModelForVision2Seq
from pathlib import Path
import tempfile
import os
import subprocess
import sys
# Try to install flash-attn at startup if not available
try:
import flash_attn
print("Flash attention already installed")
except ImportError:
print("Flash attention not found, attempting to install...")
try:
subprocess.run(
[sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
check=True,
capture_output=True,
text=True
)
print("Flash attention installed successfully")
except subprocess.CalledProcessError as e:
print(f"Could not install flash attention: {e}")
print("Continuing without flash attention...")
# Global variables for model and processor
model = None
processor = None
model_loaded = False
def load_model():
"""Load the model and processor"""
global model, processor, model_loaded
if not model_loaded:
try:
# Load processor
processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")
# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Check if flash attention is available
attn_implementation = "eager" # default
if device == "cuda":
try:
import flash_attn
attn_implementation = "flash_attention_2"
print("Using Flash Attention 2")
except ImportError:
print("Flash attention not available, using eager attention")
attn_implementation = "eager"
# Load model with appropriate settings
print(f"Loading model on {device} with {attn_implementation}...")
if device == "cuda":
# For GPU, use bfloat16 for better performance
model = AutoModelForVision2Seq.from_pretrained(
"ibm-granite/granite-docling-258M",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
device_map="auto",
trust_remote_code=True
)
            else:
                # For CPU, use float32
                model = AutoModelForVision2Seq.from_pretrained(
                    "ibm-granite/granite-docling-258M",
                    torch_dtype=torch.float32,
                    attn_implementation="eager",
                    trust_remote_code=True
                )
                # device_map="auto" already places the GPU model, so only the
                # CPU model needs an explicit move
                model = model.to(device)
model_loaded = True
print(f"Model loaded successfully on {device}")
except Exception as e:
print(f"Error loading model: {e}")
# Fallback loading without special attention
try:
processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")
model = AutoModelForVision2Seq.from_pretrained(
"ibm-granite/granite-docling-258M",
torch_dtype=torch.float32,
trust_remote_code=True
)
device = "cpu"
model = model.to(device)
model_loaded = True
print("Model loaded on CPU as fallback")
except Exception as fallback_error:
print(f"Fallback loading also failed: {fallback_error}")
raise
# Load model at startup
load_model()
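# @spaces.GPU requests a ZeroGPU slice for each call and releases it
# afterwards; duration=120 caps a single invocation at 120 seconds of GPU time.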
@spaces.GPU(duration=120)
def process_document_gpu(image, output_format="markdown"):
"""Process uploaded image to generate Docling document - GPU version"""
global model, processor
try:
# Ensure model is loaded
if not model_loaded:
load_model()
# Move model to GPU if available (for ZeroGPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
# For ZeroGPU, the model might need to be moved to GPU
if device == "cuda":
# Only move if not already on cuda
if hasattr(model, 'device') and model.device.type != 'cuda':
model = model.to(device)
print(f"Processing on {device}")
# Prepare messages
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Convert this page to docling."}
]
},
]
# Prepare inputs
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt")
# Move inputs to the same device as the model
inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
# Generate outputs with memory-efficient settings
with torch.no_grad():
if device == "cuda":
                # torch.cuda.amp.autocast is deprecated; torch.autocast is equivalent
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
generated_ids = model.generate(
**inputs,
max_new_tokens=8192,
do_sample=False,
temperature=None,
top_p=None
)
else:
generated_ids = model.generate(
**inputs,
max_new_tokens=8192,
do_sample=False
)
# Process the output
        # `inputs` is a plain dict after the device move above, so index by key
        prompt_length = inputs["input_ids"].shape[1]
trimmed_generated_ids = generated_ids[:, prompt_length:]
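        # skip_special_tokens=False: the DocTags markup itself is emitted as
        # special tokens and would be stripped from the output otherwise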
doctags = processor.batch_decode(
trimmed_generated_ids,
skip_special_tokens=False,
)[0].lstrip()
print(f"Generated {len(doctags)} characters of DocTags")
# Create Docling document
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
# Generate output based on format
if output_format == "markdown":
content = doc.export_to_markdown()
return content, None, None
elif output_format == "html":
# Create temporary file for HTML
with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tmp_file:
doc.save_as_html(Path(tmp_file.name))
html_file = tmp_file.name
return None, html_file, None
else:
# Return both formats
markdown_content = doc.export_to_markdown()
with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tmp_file:
doc.save_as_html(Path(tmp_file.name))
html_file = tmp_file.name
return markdown_content, html_file, doctags
except Exception as e:
error_msg = f"Error processing document: {str(e)}"
print(error_msg)
import traceback
print(traceback.format_exc())
return error_msg, None, None
def process_document(image, output_format="markdown"):
"""Wrapper function to handle processing"""
if image is None:
return "Please upload an image first.", None, None
# Call the GPU-decorated function
return process_document_gpu(image, output_format)
def clear_results():
"""Clear all outputs"""
return "", None, ""
# Create Gradio interface
with gr.Blocks(
title="Docling Document Converter",
theme=gr.themes.Soft(),
css="""
.header {
text-align: center;
margin-bottom: 2rem;
}
.format-selector {
margin-top: 1rem;
}
.markdown-output {
max-height: 600px;
overflow-y: auto;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f9f9f9;
}
"""
) as demo:
gr.Markdown(
"""
# πŸ“„ Docling Document Converter
Upload an image of a document page and convert it to structured markdown or HTML using the IBM Granite-Docling model.
    This Space uses ZeroGPU for efficient processing. The model converts document images into structured formats while preserving layout and formatting.
---
""",
elem_classes="header"
)
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(
label="Upload Document Image",
type="pil",
height=400,
sources=["upload", "clipboard"],
show_label=True
)
format_choice = gr.Radio(
choices=["markdown", "html", "both"],
value="markdown",
label="Output Format",
info="Choose the output format for the converted document",
elem_classes="format-selector"
)
with gr.Row():
process_btn = gr.Button(
"πŸš€ Convert Document",
variant="primary",
size="lg",
scale=2
)
clear_btn = gr.Button(
"πŸ—‘οΈ Clear",
variant="secondary",
size="lg",
scale=1
)
# Status indicator
gr.Markdown(
"""
### ℹ️ Tips:
- Upload clear, high-resolution images for best results
- The model works best with text documents, tables, and structured content
- Processing may take a few moments depending on document complexity
"""
)
with gr.Column(scale=2):
with gr.Tab("πŸ“ Markdown Output"):
markdown_output = gr.Markdown(
value="",
label="Structured Markdown",
show_copy_button=True,
elem_classes="markdown-output"
)
with gr.Tab("🌐 HTML Output"):
html_output = gr.File(
label="Download HTML File",
file_types=[".html"],
visible=True
)
with gr.Tab("🏷️ Raw DocTags"):
doctags_output = gr.Textbox(
label="Raw DocTags Output",
lines=15,
max_lines=30,
show_copy_button=True,
placeholder="Raw DocTags will appear here after processing..."
)
# Event handlers
process_btn.click(
fn=process_document,
inputs=[image_input, format_choice],
outputs=[markdown_output, html_output, doctags_output],
show_progress="full"
)
clear_btn.click(
fn=clear_results,
outputs=[markdown_output, html_output, doctags_output]
)
# Examples section
with gr.Accordion("πŸ“š Example Documents", open=False):
gr.Examples(
examples=[
["https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"],
],
inputs=[image_input],
label="Click to load an example document",
cache_examples=False
)
# Footer
gr.Markdown(
"""
---
<div style="text-align: center; margin-top: 2rem;">
<p>Powered by <a href="https://huggingface.co/ibm-granite/granite-docling-258M" target="_blank">IBM Granite-Docling-258M</a></p>
<p>Built with ❀️ using Gradio and Hugging Face Spaces</p>
</div>
"""
)
# Launch the app
if __name__ == "__main__":
demo.launch()