KhacHuy committed on
Commit
e9b2b8f
·
verified ·
1 Parent(s): 69271e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -0
app.py CHANGED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModel, AutoTokenizer
4
+ import spaces
5
+ import os
6
+ import tempfile
7
+ from PIL import Image, ImageDraw
8
+ import re # Import thư viện regular expression
9
+
10
+ # --- 1. Load Model and Tokenizer (Done only once at startup) ---
11
print("Loading model and tokenizer...")
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# No device is requested at load time; process_ocr_task moves the model to the
# GPU on each request.
_model_load_options = {
    "_attn_implementation": "flash_attention_2",
    "trust_remote_code": True,
    "use_safetensors": True,
}
# eval() switches the module to inference mode and returns the module itself.
model = AutoModel.from_pretrained(model_name, **_model_load_options).eval()
print("✅ Model loaded successfully.")
23
+
24
+ # --- Helper function to find pre-generated result images ---
25
def find_result_image(path):
    """Return the first readable result image found in directory *path*.

    A file is a candidate when its name contains "grounding" or "result".
    Candidates are tried in sorted name order so the choice is deterministic.
    Files that cannot be opened or decoded as images are reported and skipped.

    The pixel data is loaded eagerly before returning, so the image stays
    usable even if *path* (for example a temporary directory) is deleted
    afterwards.

    Args:
        path: Directory to scan (not recursive).

    Returns:
        A loaded ``PIL.Image.Image``, or ``None`` when no candidate could be
        read.
    """
    for filename in sorted(os.listdir(path)):
        if "grounding" not in filename and "result" not in filename:
            continue

        image_path = os.path.join(path, filename)
        # Sub-directories can match the name filter too; they are never images.
        if not os.path.isfile(image_path):
            continue

        try:
            result_image = Image.open(image_path)
            # Image.open is lazy: force the decode now, while the file exists.
            result_image.load()
        except Exception as e:
            # Best effort: a matching non-image file must not abort the search.
            print(f"Error opening result image {filename}: {e}")
            continue
        return result_image

    return None
34
+
35
+ # --- 2. Main Processing Function (UPDATED for multi-bbox drawing) ---
36
# Inference settings for each resolution preset offered in the UI.
_SIZE_CONFIGS = {
    "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
    "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
    "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
    "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
    "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
}
_DEFAULT_SIZE = "Gundam (Recommended)"

# Prompts for the tasks that need no extra user input.
_FREE_OCR_PROMPT = "<image>\nFree OCR."
_TASK_PROMPTS = {
    "📝 Free OCR": _FREE_OCR_PROMPT,
    "📄 Convert to Markdown": "<image>\n<|grounding|>Convert the document to markdown.",
    "📈 Parse Figure": "<image>\nParse the figure.",
}
_LOCATE_TASK = "🔍 Locate Object by Reference"

# A detection span such as <|det|>[[280, 15, 696, 997]]<|/det|>, and a single
# [x1, y1, x2, y2] box inside it (a span may list several boxes).
_DET_SPAN_PATTERN = re.compile(r"<\|det\|>(.*?)<\|/det\|>", re.DOTALL)
_BOX_PATTERN = re.compile(r"\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]")


def _build_prompt(task_type, ref_text):
    """Return the model prompt for *task_type*; raise gr.Error if 'Locate' has no reference text."""
    if task_type == _LOCATE_TASK:
        reference = (ref_text or "").strip()
        if not reference:
            raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
        return f"<image>\nLocate <|ref|>{reference}<|/ref|> in the image."
    # Unknown task types fall back to plain OCR.
    return _TASK_PROMPTS.get(task_type, _FREE_OCR_PROMPT)


def _extract_bboxes(text):
    """Return every (x1, y1, x2, y2) integer box listed inside <|det|>...<|/det|> spans of *text*."""
    boxes = []
    for span in _DET_SPAN_PATTERN.finditer(text):
        for box in _BOX_PATTERN.finditer(span.group(1)):
            boxes.append(tuple(int(value) for value in box.groups()))
    return boxes


def _draw_bboxes(image, boxes):
    """Return a copy of *image* with each box outlined in red (3 px wide)."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    width, height = image.size
    for x1_norm, y1_norm, x2_norm, y2_norm in boxes:
        # Coordinates are treated as positions in a 1000x1000 space and scaled
        # to the real image size.
        xs = sorted((int(x1_norm / 1000 * width), int(x2_norm / 1000 * width)))
        ys = sorted((int(y1_norm / 1000 * height), int(y2_norm / 1000 * height)))
        # Corners are ordered first: ImageDraw can reject a rectangle whose
        # second corner lies before the first.
        draw.rectangle([xs[0], ys[0], xs[1], ys[1]], outline="red", width=3)
    return annotated


@spaces.GPU
def process_ocr_task(image, model_size, task_type, ref_text):
    """Run DeepSeek-OCR on *image* and return the text plus an annotated image.

    Every bounding box found in the model output is drawn on a copy of the
    input image, whatever the task. When the output contains no boxes, the
    function falls back to an image file written by the model into its
    temporary output directory, if there is one.

    Args:
        image: PIL image to process, or None when nothing was uploaded.
        model_size: Name of a resolution preset; unknown names use
            "Gundam (Recommended)".
        task_type: Task label selected in the UI; unknown labels run Free OCR.
        ref_text: Text to locate; only used (and required) by the Locate task.

    Returns:
        A ``(text_result, result_image)`` tuple; ``result_image`` is None when
        there is nothing to show.

    Raises:
        gr.Error: If the Locate task is selected without reference text.
    """
    if image is None:
        return "Please upload an image first.", None

    # Validate the request before touching the GPU so bad input fails fast.
    prompt = _build_prompt(task_type, ref_text)
    config = _SIZE_CONFIGS.get(model_size, _SIZE_CONFIGS[_DEFAULT_SIZE])

    print("🚀 Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
    print("✅ Model is on GPU.")

    with tempfile.TemporaryDirectory() as output_path:
        # The model reads its input from a file, so persist the upload first.
        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)

        print(f"🏃 Running inference with prompt: {prompt}")
        text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=True,
        )

        print(f"====\n📄 Text Result: {text_result}\n====")

        # A missing text result is treated as "no boxes" instead of crashing.
        boxes = _extract_bboxes(text_result or "")

        if boxes:
            print(f"✅ Found {len(boxes)} bounding box(es). Drawing on the original image.")
            result_image_pil = _draw_bboxes(image, boxes)
        else:
            print("⚠️ No bounding box coordinates found in text result. Falling back to search for a result image file.")
            result_image_pil = find_result_image(output_path)
            if result_image_pil is not None:
                try:
                    # The file vanishes with the temporary directory, so the
                    # pixels must be decoded before leaving this block.
                    result_image_pil.load()
                except Exception as e:
                    print(f"Error loading result image: {e}")
                    result_image_pil = None

    return text_result, result_image_pil
129
+
130
+
131
+ # --- 3. Build the Gradio Interface (UPDATED) ---
132
with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
    # Intro text shown at the top of the page.
    gr.Markdown(
        """
        # 🐳 Full Demo of DeepSeek-OCR 🐳
        **💡 How to use:**
        1. **Upload an image** using the upload box.
        2. Select a **Resolution**. `Gundam` is recommended for most documents.
        3. Choose a **Task Type**:
        - **📝 Free OCR**: Extracts raw text from the image.
        - **📄 Convert to Markdown**: Converts the document into Markdown, preserving structure.
        - **📈 Parse Figure**: Extracts structured data from charts and figures.
        - **🔍 Locate Object by Reference**: Finds a specific object/text.
        4. If this helpful, please give it a like! 🙏 ❤️
        """
    )

    resolution_choices = ["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"]
    task_choices = [
        "📝 Free OCR",
        "📄 Convert to Markdown",
        "📈 Parse Figure",
        "🔍 Locate Object by Reference",
    ]

    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="🖼️ Upload Image",
                sources=["upload", "clipboard"],
            )
            model_size = gr.Dropdown(
                choices=resolution_choices,
                value="Gundam (Recommended)",
                label="⚙️ Resolution Size",
            )
            task_type = gr.Dropdown(
                choices=task_choices,
                value="📄 Convert to Markdown",
                label="🚀 Task Type",
            )
            ref_text_input = gr.Textbox(
                label="📝 Reference Text (for Locate task)",
                placeholder="e.g., the teacher, 20-10, a red car...",
                visible=False,
            )
            submit_btn = gr.Button("Process Image", variant="primary")

        # Right column: text and image results.
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
            output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

    # The reference-text box is only shown while the Locate task is selected.
    def toggle_ref_text_visibility(task):
        return gr.Textbox(visible=(task == "🔍 Locate Object by Reference"))

    ocr_inputs = [image_input, model_size, task_type, ref_text_input]
    ocr_outputs = [output_text, output_image]

    task_type.change(fn=toggle_ref_text_visibility, inputs=task_type, outputs=ref_text_input)
    submit_btn.click(fn=process_ocr_task, inputs=ocr_inputs, outputs=ocr_outputs)

    # Clickable examples: image file, resolution, task, reference text.
    example_rows = [
        ["doc_markdown.png", "Gundam (Recommended)", "📄 Convert to Markdown", ""],
        ["chart.png", "Gundam (Recommended)", "📈 Parse Figure", ""],
        ["teacher.jpg", "Base", "🔍 Locate Object by Reference", "the teacher"],
        ["math_locate.jpg", "Small", "🔍 Locate Object by Reference", "20-10"],
        ["receipt.jpg", "Base", "📝 Free OCR", ""],
    ]
    gr.Examples(
        examples=example_rows,
        inputs=ocr_inputs,
        outputs=ocr_outputs,
        fn=process_ocr_task,
        cache_examples=False,  # examples are recomputed on every run, never cached
    )
181
+
182
+ # --- 4. Launch the App ---
183
if __name__ == "__main__":
    # Create the "examples" folder on first run. The example images
    # (doc_markdown.png, chart.png, teacher.jpg, math_locate.jpg, receipt.jpg)
    # still have to be supplied by hand.
    examples_dir = "examples"
    if not os.path.exists(examples_dir):
        os.makedirs(examples_dir)

    # Queue at most 20 pending requests, then start the app with a share link.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(share=True)