KhacHuy committed on
Commit ada41e4 · verified · 1 Parent(s): fd67e65

Update app.py

Files changed (1)
  1. app.py +104 -146
app.py CHANGED
@@ -1,190 +1,148 @@
  import gradio as gr
  import torch
  from transformers import AutoModel, AutoTokenizer
- import spaces
  import os
  import tempfile
  from PIL import Image, ImageDraw
- import re  # import the regular expression library

- # --- 1. Load Model and Tokenizer (Done only once at startup) ---
- print("Loading model and tokenizer...")
  model_name = "deepseek-ai/DeepSeek-OCR"
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- # Load the model to CPU first; it will be moved to GPU during processing
  model = AutoModel.from_pretrained(
      model_name,
-     # _attn_implementation="flash_attention_2",
-     attn_implementation="eager",
      trust_remote_code=True,
-     use_safetensors=True,
  )
  model = model.eval()
- print("✅ Model loaded successfully.")

- # --- Helper function to find pre-generated result images ---
  def find_result_image(path):
      for filename in os.listdir(path):
          if "grounding" in filename or "result" in filename:
              try:
-                 image_path = os.path.join(path, filename)
-                 return Image.open(image_path)
-             except Exception as e:
-                 print(f"Error opening result image {filename}: {e}")
      return None

- # --- 2. Main Processing Function (UPDATED for multi-bbox drawing) ---
- # @spaces.GPU
- def process_ocr_task(image, model_size, task_type, ref_text):
-     """
-     Processes an image with DeepSeek-OCR for all supported tasks.
-     Now draws ALL detected bounding boxes for ANY task.
-     """
-     if image is None:
-         return "Please upload an image first.", None

-     print("🚀 Moving model to GPU...")
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model = model.to(device)

      with tempfile.TemporaryDirectory() as output_path:
-         # Build the prompt... (same as before)
-         if task_type == "📝 Free OCR":
-             prompt = "<image>\nFree OCR."
-         elif task_type == "📄 Convert to Markdown":
-             prompt = "<image>\n<|grounding|>Convert the document to markdown."
-         elif task_type == "📈 Parse Figure":
-             prompt = "<image>\nParse the figure."
-         elif task_type == "🔍 Locate Object by Reference":
-             if not ref_text or ref_text.strip() == "":
-                 raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
-             prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
-         else:
-             prompt = "<image>\nFree OCR."
-
-         temp_image_path = os.path.join(output_path, "temp_image.png")
-         image.save(temp_image_path)
-
-         # Configure model size... (same as before)
-         size_configs = {
-             "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
-             "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
-             "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
-             "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
-             "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
-         }
-         config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
-
-         print(f"🏃 Running inference with prompt: {prompt}")
-         text_result = model_gpu.infer(
              tokenizer,
              prompt=prompt,
-             image_file=temp_image_path,
              output_path=output_path,
              base_size=config["base_size"],
              image_size=config["image_size"],
              crop_mode=config["crop_mode"],
              save_results=True,
-             test_compress=True,
-             eval_mode=True,
          )

-         print(f"====\n📄 Text Result: {text_result}\n====")

-         # --- NEW LOGIC: Always try to find and draw all bounding boxes ---
-         result_image_pil = None
-
-         # Define the pattern to find all coordinates like [[280, 15, 696, 997]]
          pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
-         matches = list(pattern.finditer(text_result))  # Use finditer to get all matches

          if matches:
-             print(f"✅ Found {len(matches)} bounding box(es). Drawing on the original image.")
-
-             # Create a copy of the original image to draw on
-             image_with_bboxes = image.copy()
-             draw = ImageDraw.Draw(image_with_bboxes)
-             w, h = image.size  # Get original image dimensions
-
-             for match in matches:
-                 # Extract coordinates as integers
-                 coords_norm = [int(c) for c in match.groups()]
-                 x1_norm, y1_norm, x2_norm, y2_norm = coords_norm
-
-                 # Scale the normalized coordinates (from 1000x1000 space) to the image's actual size
-                 x1 = int(x1_norm / 1000 * w)
-                 y1 = int(y1_norm / 1000 * h)
-                 x2 = int(x2_norm / 1000 * w)
-                 y2 = int(y2_norm / 1000 * h)
-
-                 # Draw the rectangle with a red outline, 3 pixels wide
-                 draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
-
-             result_image_pil = image_with_bboxes
-         else:
-             # If no coordinates are found in the text, fall back to finding a pre-generated image
-             print("⚠️ No bounding box coordinates found in text result. Falling back to search for a result image file.")
-             result_image_pil = find_result_image(output_path)
-
-         return text_result, result_image_pil
-
-
- # --- 3. Build the Gradio Interface (UPDATED) ---
  with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
-     gr.Markdown(
-         """
-         # 🐳 Full Demo of DeepSeek-OCR 🐳
-         **💡 How to use:**
-         1. **Upload an image** using the upload box.
-         2. Select a **Resolution**. `Gundam` is recommended for most documents.
-         3. Choose a **Task Type**:
-            - **📝 Free OCR**: Extracts raw text from the image.
-            - **📄 Convert to Markdown**: Converts the document into Markdown, preserving structure.
-            - **📈 Parse Figure**: Extracts structured data from charts and figures.
-            - **🔍 Locate Object by Reference**: Finds a specific object/text.
-         4. If this is helpful, please give it a like! 🙏 ❤️
-         """
-     )

      with gr.Row():
          with gr.Column(scale=1):
-             image_input = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
-             model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Gundam (Recommended)", label="⚙️ Resolution Size")
-             task_type = gr.Dropdown(choices=["📝 Free OCR", "📄 Convert to Markdown", "📈 Parse Figure", "🔍 Locate Object by Reference"], value="📄 Convert to Markdown", label="🚀 Task Type")
-             ref_text_input = gr.Textbox(label="📝 Reference Text (for Locate task)", placeholder="e.g., the teacher, 20-10, a red car...", visible=False)
-             submit_btn = gr.Button("Process Image", variant="primary")

          with gr.Column(scale=2):
-             output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
-             output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")
-
-     # --- UI Interaction Logic ---
-     def toggle_ref_text_visibility(task):
-         return gr.Textbox(visible=True) if task == "🔍 Locate Object by Reference" else gr.Textbox(visible=False)
-
-     task_type.change(fn=toggle_ref_text_visibility, inputs=task_type, outputs=ref_text_input)
-     submit_btn.click(fn=process_ocr_task, inputs=[image_input, model_size, task_type, ref_text_input], outputs=[output_text, output_image])
-
-     # --- UPDATED Example Images and Tasks ---
-     gr.Examples(
-         examples=[
-             ["doc_markdown.png", "Gundam (Recommended)", "📄 Convert to Markdown", ""],
-             ["chart.png", "Gundam (Recommended)", "📈 Parse Figure", ""],
-             ["teacher.jpg", "Base", "🔍 Locate Object by Reference", "the teacher"],
-             ["math_locate.jpg", "Small", "🔍 Locate Object by Reference", "20-10"],
-             ["receipt.jpg", "Base", "📝 Free OCR", ""],
-         ],
-         inputs=[image_input, model_size, task_type, ref_text_input],
-         outputs=[output_text, output_image],
-         fn=process_ocr_task,
-         cache_examples=False,  # Disable caching to ensure examples run every time
-     )
-
- # --- 4. Launch the App ---
  if __name__ == "__main__":
-     if not os.path.exists("examples"):
-         os.makedirs("examples")
-     # Make sure to have the correct image files in your "examples" folder
-     # e.g., doc_markdown.png, chart.png, teacher.jpg, math_locate.jpg, receipt.jpg
-
-     demo.queue(max_size=20).launch(share=True)
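Both the removed code above and the updated code below decode DeepSeek-OCR's grounding output the same way: each <|det|>[[x1, y1, x2, y2]]<|/det|> tag carries coordinates normalized to a 1000×1000 grid, which are rescaled to the real image size before drawing. A minimal standalone sketch of that post-processing (the helper name, sample string, and canvas size are invented purely for illustration):

import re
from PIL import Image, ImageDraw

# Same pattern as in app.py: matches <|det|>[[x1, y1, x2, y2]]<|/det|> spans.
DET_PATTERN = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")

def draw_detections(image, text_result):
    """Draw every detected box onto a copy of `image`; return None if no boxes were found."""
    matches = list(DET_PATTERN.finditer(text_result))
    if not matches:
        return None
    out = image.copy()
    draw = ImageDraw.Draw(out)
    w, h = image.size
    for m in matches:
        x1, y1, x2, y2 = (int(c) for c in m.groups())
        # Model coordinates live in a 1000x1000 reference frame; rescale to pixel space.
        box = [int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h)]
        draw.rectangle(box, outline="red", width=3)
    return out

# Invented sample: one box on a blank 800x600 canvas.
sample_text = "<|ref|>the teacher<|/ref|><|det|>[[100, 120, 480, 900]]<|/det|>"
canvas = Image.new("RGB", (800, 600), "white")
boxed = draw_detections(canvas, sample_text)  # PIL image with one red rectangle

Both versions of app.py apply this same divide-by-1000 scaling regardless of which resolution preset produced the output.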
 
  import gradio as gr
  import torch
  from transformers import AutoModel, AutoTokenizer
  import os
  import tempfile
  from PIL import Image, ImageDraw
+ import re

+ # -----------------------------------------
+ # 1. Load model ONCE at startup (CPU)
+ # -----------------------------------------
+ print("🔄 Loading model and tokenizer...")
  model_name = "deepseek-ai/DeepSeek-OCR"
+
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
  model = AutoModel.from_pretrained(
      model_name,
      trust_remote_code=True,
+     use_safetensors=True
  )
+
  model = model.eval()
+ print("✅ Model loaded successfully (CPU mode)!")

+ # -----------------------------------------
+ # Helper: find generated result images
+ # -----------------------------------------
  def find_result_image(path):
      for filename in os.listdir(path):
          if "grounding" in filename or "result" in filename:
              try:
+                 return Image.open(os.path.join(path, filename))
+             except:
+                 continue
      return None


+ # -----------------------------------------
+ # 2. OCR main function
+ # -----------------------------------------
+ def process_ocr_task(image, model_size, task_type, ref_text):

+     if image is None:
+         return "Please upload image first.", None
+
+     print("⚙️ Running OCR (CPU mode)...")
+
+     # Create prompt
+     if task_type == "📝 Free OCR":
+         prompt = "<image>\nFree OCR."
+     elif task_type == "📄 Convert to Markdown":
+         prompt = "<image>\n<|grounding|>Convert document to markdown."
+     elif task_type == "📈 Parse Figure":
+         prompt = "<image>\nParse the figure."
+     elif task_type == "🔍 Locate Object by Reference":
+         if not ref_text.strip():
+             raise gr.Error("Reference text required!")
+         prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
+     else:
+         prompt = "<image>\nFree OCR."
+
+     # Size configs
+     size_configs = {
+         "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
+         "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
+         "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
+         "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
+         "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
+     }
+     config = size_configs[model_size]
+
+     # Temporary path
      with tempfile.TemporaryDirectory() as output_path:
+         img_path = os.path.join(output_path, "input.png")
+         image.save(img_path)
+
+         # Run model
+         text_result = model.infer(
              tokenizer,
              prompt=prompt,
+             image_file=img_path,
              output_path=output_path,
              base_size=config["base_size"],
              image_size=config["image_size"],
              crop_mode=config["crop_mode"],
              save_results=True,
+             eval_mode=True
          )

+         print("📜 Output text:", text_result[:200])

+         # Draw bounding box if exists
          pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
+         matches = list(pattern.finditer(text_result))

          if matches:
+             result_img = image.copy()
+             draw = ImageDraw.Draw(result_img)
+             w, h = image.size
+
+             for m in matches:
+                 x1n, y1n, x2n, y2n = map(int, m.groups())
+                 draw.rectangle([
+                     int(x1n/1000*w),
+                     int(y1n/1000*h),
+                     int(x2n/1000*w),
+                     int(y2n/1000*h),
+                 ], outline="red", width=3)
+
+             return text_result, result_img
+
+         return text_result, find_result_image(output_path)
+
+
+ # -----------------------------------------
+ # 3. UI Layout
+ # -----------------------------------------
  with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("## DeepSeek-OCR Demo - CPU Mode")

      with gr.Row():
          with gr.Column(scale=1):
+             image_input = gr.Image(type="pil", label="Upload Image")
+             model_size = gr.Dropdown(
+                 ["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
+                 value="Gundam (Recommended)"
+             )
+             task_type = gr.Dropdown(
+                 ["📝 Free OCR", "📄 Convert to Markdown", "📈 Parse Figure", "🔍 Locate Object by Reference"],
+                 value="📄 Convert to Markdown"
+             )
+             ref_text = gr.Textbox(visible=False)
+             btn = gr.Button("🚀 Process")

          with gr.Column(scale=2):
+             out_text = gr.Textbox(lines=12, show_copy_button=True)
+             out_image = gr.Image(type="pil", label="Result")
+
+     def toggle(t):
+         return gr.Textbox(visible=(t == "🔍 Locate Object by Reference"))
+
+     task_type.change(toggle, task_type, ref_text)
+     btn.click(process_ocr_task, [image_input, model_size, task_type, ref_text], [out_text, out_image])
+
+
  if __name__ == "__main__":
+     demo.queue().launch()
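Since the updated handler is an ordinary Python function, it can be smoke-tested without launching the Gradio UI once the module-level model and tokenizer have loaded. A rough sketch, assuming the new file is saved as app.py on the Python path and that test_doc.png / boxed_result.png stand in for local files of your choosing (both names are placeholders):

from PIL import Image

import app  # importing runs the module-level model load; this is slow on CPU

img = Image.open("test_doc.png")  # placeholder path, substitute any document image

text, boxed = app.process_ocr_task(
    img,
    model_size="Gundam (Recommended)",   # any key of size_configs
    task_type="📝 Free OCR",
    ref_text="",                         # only consulted by the Locate task
)

print(text)
if boxed is not None:
    boxed.save("boxed_result.png")  # placeholder output path

The if __name__ == "__main__" guard keeps the import from starting the web server, so only the model load and the single inference run.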