akhaliq HF Staff commited on
Commit
4422fa9
Β·
verified Β·
1 Parent(s): 0d29d88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -59
app.py CHANGED
@@ -4,39 +4,125 @@ import gradio as gr
4
  from docling_core.types.doc import DoclingDocument
5
  from docling_core.types.doc.document import DocTagsDocument
6
  from transformers import AutoProcessor, AutoModelForVision2Seq
7
- from transformers.image_utils import load_image
8
  from pathlib import Path
9
  import tempfile
 
10
  import subprocess
11
  import sys
12
 
13
- # Install flash attention before anything else
14
- @spaces.GPU(duration=120)
15
- def install_flash_attention():
16
- """Install flash attention if not available"""
 
 
17
  try:
18
- import flash_attn
19
- return True
20
- except ImportError:
21
- print("Installing flash-attention...")
22
- subprocess.check_call([sys.executable, "-m", "pip", "install", "flash-attn"])
23
- return True
 
 
 
 
24
 
25
- install_flash_attention()
 
 
 
26
 
27
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Load processor and model
30
- processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")
31
- model = AutoModelForVision2Seq.from_pretrained(
32
- "ibm-granite/granite-docling-258M",
33
- dtype=torch.bfloat16,
34
- attn_implementation="flash_attention_2",
35
- ).to(DEVICE)
36
 
37
- def process_document(image, output_format="markdown"):
38
- """Process uploaded image to generate Docling document"""
 
 
 
39
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Prepare messages
41
  messages = [
42
  {
@@ -51,12 +137,29 @@ def process_document(image, output_format="markdown"):
51
  # Prepare inputs
52
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
53
  inputs = processor(text=prompt, images=[image], return_tensors="pt")
54
- inputs = inputs.to(DEVICE)
55
 
56
- # Generate outputs
 
 
 
57
  with torch.no_grad():
58
- generated_ids = model.generate(**inputs, max_new_tokens=8192)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
60
  prompt_length = inputs.input_ids.shape[1]
61
  trimmed_generated_ids = generated_ids[:, prompt_length:]
62
  doctags = processor.batch_decode(
@@ -64,6 +167,8 @@ def process_document(image, output_format="markdown"):
64
  skip_special_tokens=False,
65
  )[0].lstrip()
66
 
 
 
67
  # Create Docling document
68
  doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
69
  doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
@@ -87,28 +192,55 @@ def process_document(image, output_format="markdown"):
87
  return markdown_content, html_file, doctags
88
 
89
  except Exception as e:
90
- return f"Error processing document: {str(e)}", None, None
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  def clear_results():
93
  """Clear all outputs"""
94
- return None, None, None
95
 
 
96
  with gr.Blocks(
97
  title="Docling Document Converter",
98
  theme=gr.themes.Soft(),
99
  css="""
100
- .header { text-align: center; margin-bottom: 2rem; }
101
- .format-selector { margin-top: 1rem; }
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  """
103
  ) as demo:
104
  gr.Markdown(
105
  """
106
- # Docling Document Converter
107
- Upload an image of a document page and convert it to structured markdown or HTML using the Granite-Docling model.
108
 
109
- <div class="header">
110
- <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a>
111
- </div>
 
 
112
  """,
113
  elem_classes="header"
114
  )
@@ -118,50 +250,67 @@ with gr.Blocks(
118
  image_input = gr.Image(
119
  label="Upload Document Image",
120
  type="pil",
121
- height=600,
122
- width=400,
123
- sources=["upload", "webcam", "clipboard"],
124
- show_share_button=True
125
  )
126
 
127
  format_choice = gr.Radio(
128
  choices=["markdown", "html", "both"],
129
  value="markdown",
130
  label="Output Format",
 
131
  elem_classes="format-selector"
132
  )
133
 
134
- process_btn = gr.Button(
135
- "Convert to Docling",
136
- variant="primary",
137
- size="lg"
138
- )
 
 
 
 
 
 
 
 
 
139
 
140
- clear_btn = gr.Button(
141
- "Clear Results",
142
- variant="secondary"
 
 
 
 
 
143
  )
144
 
145
  with gr.Column(scale=2):
146
- with gr.Tab("Markdown Output"):
147
  markdown_output = gr.Markdown(
 
148
  label="Structured Markdown",
149
  show_copy_button=True,
150
- lines=20
151
  )
152
 
153
- with gr.Tab("HTML Output"):
154
  html_output = gr.File(
155
  label="Download HTML File",
156
  file_types=[".html"],
157
- show_download_button=True
158
  )
159
 
160
- with gr.Tab("Raw DocTags"):
161
  doctags_output = gr.Textbox(
162
  label="Raw DocTags Output",
163
  lines=15,
164
- show_copy_button=True
 
 
165
  )
166
 
167
  # Event handlers
@@ -177,14 +326,28 @@ with gr.Blocks(
177
  outputs=[markdown_output, html_output, doctags_output]
178
  )
179
 
180
- # Example button
181
- gr.Examples(
182
- examples=[
183
- ["https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"],
184
- ],
185
- inputs=[image_input],
186
- label="Try this example"
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
 
 
189
  if __name__ == "__main__":
190
  demo.launch()
 
4
  from docling_core.types.doc import DoclingDocument
5
  from docling_core.types.doc.document import DocTagsDocument
6
  from transformers import AutoProcessor, AutoModelForVision2Seq
 
7
  from pathlib import Path
8
  import tempfile
9
+ import os
10
  import subprocess
11
  import sys
12
 
13
+ # Try to install flash-attn at startup if not available
14
+ try:
15
+ import flash_attn
16
+ print("Flash attention already installed")
17
+ except ImportError:
18
+ print("Flash attention not found, attempting to install...")
19
  try:
20
+ subprocess.run(
21
+ [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
22
+ check=True,
23
+ capture_output=True,
24
+ text=True
25
+ )
26
+ print("Flash attention installed successfully")
27
+ except subprocess.CalledProcessError as e:
28
+ print(f"Could not install flash attention: {e}")
29
+ print("Continuing without flash attention...")
30
 
31
+ # Global variables for model and processor
32
+ model = None
33
+ processor = None
34
+ model_loaded = False
35
 
36
+ def load_model():
37
+ """Load the model and processor"""
38
+ global model, processor, model_loaded
39
+
40
+ if not model_loaded:
41
+ try:
42
+ # Load processor
43
+ processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")
44
+
45
+ # Determine device
46
+ device = "cuda" if torch.cuda.is_available() else "cpu"
47
+
48
+ # Check if flash attention is available
49
+ attn_implementation = "eager" # default
50
+ if device == "cuda":
51
+ try:
52
+ import flash_attn
53
+ attn_implementation = "flash_attention_2"
54
+ print("Using Flash Attention 2")
55
+ except ImportError:
56
+ print("Flash attention not available, using eager attention")
57
+ attn_implementation = "eager"
58
+
59
+ # Load model with appropriate settings
60
+ print(f"Loading model on {device} with {attn_implementation}...")
61
+
62
+ if device == "cuda":
63
+ # For GPU, use bfloat16 for better performance
64
+ model = AutoModelForVision2Seq.from_pretrained(
65
+ "ibm-granite/granite-docling-258M",
66
+ torch_dtype=torch.bfloat16,
67
+ attn_implementation=attn_implementation,
68
+ device_map="auto",
69
+ trust_remote_code=True
70
+ )
71
+ else:
72
+ # For CPU, use float32
73
+ model = AutoModelForVision2Seq.from_pretrained(
74
+ "ibm-granite/granite-docling-258M",
75
+ torch_dtype=torch.float32,
76
+ attn_implementation="eager",
77
+ trust_remote_code=True
78
+ )
79
+ model = model.to(device)
80
+
81
+ model_loaded = True
82
+ print(f"Model loaded successfully on {device}")
83
+
84
+ except Exception as e:
85
+ print(f"Error loading model: {e}")
86
+ # Fallback loading without special attention
87
+ try:
88
+ processor = AutoProcessor.from_pretrained("ibm-granite/granite-docling-258M")
89
+ model = AutoModelForVision2Seq.from_pretrained(
90
+ "ibm-granite/granite-docling-258M",
91
+ torch_dtype=torch.float32,
92
+ trust_remote_code=True
93
+ )
94
+ device = "cpu"
95
+ model = model.to(device)
96
+ model_loaded = True
97
+ print("Model loaded on CPU as fallback")
98
+ except Exception as fallback_error:
99
+ print(f"Fallback loading also failed: {fallback_error}")
100
+ raise
101
 
102
+ # Load model at startup
103
+ load_model()
 
 
 
 
 
104
 
105
+ @spaces.GPU(duration=120)
106
+ def process_document_gpu(image, output_format="markdown"):
107
+ """Process uploaded image to generate Docling document - GPU version"""
108
+ global model, processor
109
+
110
  try:
111
+ # Ensure model is loaded
112
+ if not model_loaded:
113
+ load_model()
114
+
115
+ # Move model to GPU if available (for ZeroGPU)
116
+ device = "cuda" if torch.cuda.is_available() else "cpu"
117
+
118
+ # For ZeroGPU, the model might need to be moved to GPU
119
+ if device == "cuda":
120
+ # Only move if not already on cuda
121
+ if hasattr(model, 'device') and model.device.type != 'cuda':
122
+ model = model.to(device)
123
+
124
+ print(f"Processing on {device}")
125
+
126
  # Prepare messages
127
  messages = [
128
  {
 
137
  # Prepare inputs
138
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
139
  inputs = processor(text=prompt, images=[image], return_tensors="pt")
 
140
 
141
+ # Move inputs to the same device as the model
142
+ inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
143
+
144
+ # Generate outputs with memory-efficient settings
145
  with torch.no_grad():
146
+ if device == "cuda":
147
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
148
+ generated_ids = model.generate(
149
+ **inputs,
150
+ max_new_tokens=8192,
151
+ do_sample=False,
152
+ temperature=None,
153
+ top_p=None
154
+ )
155
+ else:
156
+ generated_ids = model.generate(
157
+ **inputs,
158
+ max_new_tokens=8192,
159
+ do_sample=False
160
+ )
161
 
162
+ # Process the output
163
  prompt_length = inputs.input_ids.shape[1]
164
  trimmed_generated_ids = generated_ids[:, prompt_length:]
165
  doctags = processor.batch_decode(
 
167
  skip_special_tokens=False,
168
  )[0].lstrip()
169
 
170
+ print(f"Generated {len(doctags)} characters of DocTags")
171
+
172
  # Create Docling document
173
  doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
174
  doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
 
192
  return markdown_content, html_file, doctags
193
 
194
  except Exception as e:
195
+ error_msg = f"Error processing document: {str(e)}"
196
+ print(error_msg)
197
+ import traceback
198
+ print(traceback.format_exc())
199
+ return error_msg, None, None
200
+
201
+ def process_document(image, output_format="markdown"):
202
+ """Wrapper function to handle processing"""
203
+ if image is None:
204
+ return "Please upload an image first.", None, None
205
+
206
+ # Call the GPU-decorated function
207
+ return process_document_gpu(image, output_format)
208
 
209
  def clear_results():
210
  """Clear all outputs"""
211
+ return "", None, ""
212
 
213
+ # Create Gradio interface
214
  with gr.Blocks(
215
  title="Docling Document Converter",
216
  theme=gr.themes.Soft(),
217
  css="""
218
+ .header {
219
+ text-align: center;
220
+ margin-bottom: 2rem;
221
+ }
222
+ .format-selector {
223
+ margin-top: 1rem;
224
+ }
225
+ .markdown-output {
226
+ max-height: 600px;
227
+ overflow-y: auto;
228
+ padding: 10px;
229
+ border: 1px solid #ddd;
230
+ border-radius: 5px;
231
+ background-color: #f9f9f9;
232
+ }
233
  """
234
  ) as demo:
235
  gr.Markdown(
236
  """
237
+ # πŸ“„ Docling Document Converter
 
238
 
239
+ Upload an image of a document page and convert it to structured markdown or HTML using the IBM Granite-Docling model.
240
+
241
+ This space uses ZeroGPU for efficient processing. The model converts document images into structured formats while preserving layout and formatting.
242
+
243
+ ---
244
  """,
245
  elem_classes="header"
246
  )
 
250
  image_input = gr.Image(
251
  label="Upload Document Image",
252
  type="pil",
253
+ height=400,
254
+ sources=["upload", "clipboard"],
255
+ show_label=True
 
256
  )
257
 
258
  format_choice = gr.Radio(
259
  choices=["markdown", "html", "both"],
260
  value="markdown",
261
  label="Output Format",
262
+ info="Choose the output format for the converted document",
263
  elem_classes="format-selector"
264
  )
265
 
266
+ with gr.Row():
267
+ process_btn = gr.Button(
268
+ "πŸš€ Convert Document",
269
+ variant="primary",
270
+ size="lg",
271
+ scale=2
272
+ )
273
+
274
+ clear_btn = gr.Button(
275
+ "πŸ—‘οΈ Clear",
276
+ variant="secondary",
277
+ size="lg",
278
+ scale=1
279
+ )
280
 
281
+ # Status indicator
282
+ gr.Markdown(
283
+ """
284
+ ### ℹ️ Tips:
285
+ - Upload clear, high-resolution images for best results
286
+ - The model works best with text documents, tables, and structured content
287
+ - Processing may take a few moments depending on document complexity
288
+ """
289
  )
290
 
291
  with gr.Column(scale=2):
292
+ with gr.Tab("πŸ“ Markdown Output"):
293
  markdown_output = gr.Markdown(
294
+ value="",
295
  label="Structured Markdown",
296
  show_copy_button=True,
297
+ elem_classes="markdown-output"
298
  )
299
 
300
+ with gr.Tab("🌐 HTML Output"):
301
  html_output = gr.File(
302
  label="Download HTML File",
303
  file_types=[".html"],
304
+ visible=True
305
  )
306
 
307
+ with gr.Tab("🏷️ Raw DocTags"):
308
  doctags_output = gr.Textbox(
309
  label="Raw DocTags Output",
310
  lines=15,
311
+ max_lines=30,
312
+ show_copy_button=True,
313
+ placeholder="Raw DocTags will appear here after processing..."
314
  )
315
 
316
  # Event handlers
 
326
  outputs=[markdown_output, html_output, doctags_output]
327
  )
328
 
329
+ # Examples section
330
+ with gr.Accordion("πŸ“š Example Documents", open=False):
331
+ gr.Examples(
332
+ examples=[
333
+ ["https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png"],
334
+ ],
335
+ inputs=[image_input],
336
+ label="Click to load an example document",
337
+ cache_examples=False
338
+ )
339
+
340
+ # Footer
341
+ gr.Markdown(
342
+ """
343
+ ---
344
+ <div style="text-align: center; margin-top: 2rem;">
345
+ <p>Powered by <a href="https://huggingface.co/ibm-granite/granite-docling-258M" target="_blank">IBM Granite-Docling-258M</a></p>
346
+ <p>Built with ❀️ using Gradio and Hugging Face Spaces</p>
347
+ </div>
348
+ """
349
  )
350
 
351
+ # Launch the app
352
  if __name__ == "__main__":
353
  demo.launch()