"""Gradio app: convert document-page images to Markdown/HTML with IBM Granite-Docling."""

import spaces  # import first so ZeroGPU can patch CUDA initialization

import torch
import gradio as gr
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import AutoProcessor, AutoModelForVision2Seq
from pathlib import Path
import tempfile
import subprocess
import sys
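

# Optional speed-up: flash-attn accelerates attention on CUDA GPUs but ships as
# a source build. Installation here is best-effort; on failure the app simply
# falls back to eager attention below.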
try:
    import flash_attn
    print("Flash attention already installed")
except ImportError:
    print("Flash attention not found, attempting to install...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=True,
            capture_output=True,
            text=True,
        )
        print("Flash attention installed successfully")
    except subprocess.CalledProcessError as e:
        print(f"Could not install flash attention: {e}")
        print("Continuing without flash attention...")


model = None
processor = None
model_loaded = False
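

# load_model() is idempotent: it populates the module globals once, so the
# defensive call inside process_document_gpu() is a no-op after a successful load.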
def load_model():
    """Load the model and processor once, caching them in module globals."""
    global model, processor, model_loaded

    if not model_loaded:
        try:
            processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")

            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Prefer Flash Attention 2 on GPU when the package is importable.
            attn_implementation = "eager"
            if device == "cuda":
                try:
                    import flash_attn
                    attn_implementation = "flash_attention_2"
                    print("Using Flash Attention 2")
                except ImportError:
                    print("Flash attention not available, using eager attention")
                    attn_implementation = "eager"

            print(f"Loading model on {device} with {attn_implementation}...")

            if device == "cuda":
                model = AutoModelForVision2Seq.from_pretrained(
                    "ibm-granite/granite-docling-258M",
                    torch_dtype=torch.bfloat16,
                    attn_implementation=attn_implementation,
                    device_map="auto",
                    trust_remote_code=True,
                )
            else:
                model = AutoModelForVision2Seq.from_pretrained(
                    "ibm-granite/granite-docling-258M",
                    torch_dtype=torch.float32,
                    attn_implementation="eager",
                    trust_remote_code=True,
                )
                model = model.to(device)

            model_loaded = True
            print(f"Model loaded successfully on {device}")

        except Exception as e:
            print(f"Error loading model: {e}")
            # Last resort: retry on CPU with default settings.
            try:
                processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")
                model = AutoModelForVision2Seq.from_pretrained(
                    "ibm-granite/granite-docling-258M",
                    torch_dtype=torch.float32,
                    trust_remote_code=True,
                )
                device = "cpu"
                model = model.to(device)
                model_loaded = True
                print("Model loaded on CPU as fallback")
            except Exception as fallback_error:
                print(f"Fallback loading also failed: {fallback_error}")
                raise
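

# Load at import time so startup, not the first user request, pays the model cost.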
load_model()
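

# GPU entry point: on ZeroGPU Spaces, @spaces.GPU borrows a GPU for the call
# (capped at 120 s here). Pipeline: image -> chat prompt -> DocTags via
# model.generate() -> DoclingDocument -> markdown / HTML.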
@spaces.GPU(duration=120)
def process_document_gpu(image, output_format="markdown"):
    """Process an uploaded image into a Docling document (GPU version)."""
    global model, processor

    try:
        if not model_loaded:
            load_model()

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Make sure the model is on the GPU that ZeroGPU just attached.
        if device == "cuda":
            if hasattr(model, 'device') and model.device.type != 'cuda':
                model = model.to(device)

        print(f"Processing on {device}")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Convert this page to docling."}
                ]
            },
        ]

        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt")

        # Move all tensors to the target device (this turns `inputs` into a plain dict).
        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

        with torch.no_grad():
            if device == "cuda":
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    # Greedy decoding; passing None explicitly silences
                    # sampling-parameter warnings when do_sample=False.
                    generated_ids = model.generate(
                        **inputs,
                        max_new_tokens=8192,
                        do_sample=False,
                        temperature=None,
                        top_p=None
                    )
            else:
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=8192,
                    do_sample=False
                )

        # Strip the prompt tokens; keep special tokens, which carry the DocTags markup.
        prompt_length = inputs["input_ids"].shape[1]
        trimmed_generated_ids = generated_ids[:, prompt_length:]
        doctags = processor.batch_decode(
            trimmed_generated_ids,
            skip_special_tokens=False,
        )[0].lstrip()

        print(f"Generated {len(doctags)} characters of DocTags")

        # Build a DoclingDocument from the generated DocTags plus the source image.
        doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
        doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")

        if output_format == "markdown":
            content = doc.export_to_markdown()
            return content, None, None
        elif output_format == "html":
            with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tmp_file:
                doc.save_as_html(Path(tmp_file.name))
                html_file = tmp_file.name
            return None, html_file, None
        else:
            # "both": markdown, an HTML file, and the raw DocTags.
            markdown_content = doc.export_to_markdown()
            with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False) as tmp_file:
                doc.save_as_html(Path(tmp_file.name))
                html_file = tmp_file.name
            return markdown_content, html_file, doctags

    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        return error_msg, None, None
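

# Thin wrapper: validating input on the CPU side avoids burning a ZeroGPU
# allocation on an empty submission.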
def process_document(image, output_format="markdown"):
    """Wrapper function to handle processing."""
    if image is None:
        return "Please upload an image first.", None, None

    return process_document_gpu(image, output_format)


def clear_results():
    """Clear all outputs."""
    return "", None, ""


with gr.Blocks(
    title="Docling Document Converter",
    theme=gr.themes.Soft(),
    css="""
    .header {
        text-align: center;
        margin-bottom: 2rem;
    }
    .format-selector {
        margin-top: 1rem;
    }
    .markdown-output {
        max-height: 600px;
        overflow-y: auto;
        padding: 10px;
        border: 1px solid #ddd;
        border-radius: 5px;
        background-color: #f9f9f9;
    }
    """
) as demo:
    gr.Markdown(
        """
        # 📄 Docling Document Converter

        Upload an image of a document page and convert it to structured markdown or HTML using the IBM Granite-Docling model.

        This space uses ZeroGPU for efficient processing. The model converts document images into structured formats while preserving layout and formatting.

        ---
        """,
        elem_classes="header"
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Document Image",
                type="pil",
                height=400,
                sources=["upload", "clipboard"],
                show_label=True
            )

            format_choice = gr.Radio(
                choices=["markdown", "html", "both"],
                value="markdown",
                label="Output Format",
                info="Choose the output format for the converted document",
                elem_classes="format-selector"
            )

            with gr.Row():
                process_btn = gr.Button(
                    "🔄 Convert Document",
                    variant="primary",
                    size="lg",
                    scale=2
                )

                clear_btn = gr.Button(
                    "🗑️ Clear",
                    variant="secondary",
                    size="lg",
                    scale=1
                )

            gr.Markdown(
                """
                ### ℹ️ Tips:
                - Upload clear, high-resolution images for best results
                - The model works best with text documents, tables, and structured content
                - Processing may take a few moments depending on document complexity
                """
            )

        with gr.Column(scale=2):
            with gr.Tab("📝 Markdown Output"):
                markdown_output = gr.Markdown(
                    value="",
                    label="Structured Markdown",
                    show_copy_button=True,
                    elem_classes="markdown-output"
                )

            with gr.Tab("🌐 HTML Output"):
                html_output = gr.File(
                    label="Download HTML File",
                    file_types=[".html"],
                    visible=True
                )

            with gr.Tab("🏷️ Raw DocTags"):
                doctags_output = gr.Textbox(
                    label="Raw DocTags Output",
                    lines=15,
                    max_lines=30,
                    show_copy_button=True,
                    placeholder="Raw DocTags will appear here after processing..."
                )

    process_btn.click(
        fn=process_document,
        inputs=[image_input, format_choice],
        outputs=[markdown_output, html_output, doctags_output],
        show_progress="full"
    )

    clear_btn.click(
        fn=clear_results,
        outputs=[markdown_output, html_output, doctags_output]
    )

    with gr.Accordion("📚 Example Documents", open=False):
        gr.Examples(
            examples=[
                ["https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"],
            ],
            inputs=[image_input],
            label="Click to load an example document",
            cache_examples=False
        )

    gr.Markdown(
        """
        ---
        <div style="text-align: center; margin-top: 2rem;">
            <p>Powered by <a href="https://huggingface.co/ibm-granite/granite-docling-258M" target="_blank">IBM Granite-Docling-258M</a></p>
            <p>Built with ❤️ using Gradio and Hugging Face Spaces</p>
        </div>
        """
    )
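

# Hypothetical programmatic use, bypassing the UI (assumes a local page.png):
#   from PIL import Image
#   md, html_path, tags = process_document(Image.open("page.png"), "both")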

if __name__ == "__main__":
    demo.launch()