make789 committed
Commit 7f46686 · verified · 1 Parent(s): 0e7bad9

Upload ocr_service.py

Files changed (1)
  1. ocr_service.py +561 -101
ocr_service.py CHANGED
@@ -2,19 +2,59 @@ import asyncio
 import json
 import math
 import os
+import platform
 import secrets
+import tempfile
 from collections import defaultdict, deque
+from pathlib import Path
 from time import monotonic
-from typing import Any, Deque, DefaultDict
+from typing import Any, Deque, DefaultDict, Optional

-import cv2
 import numpy as np
 from fastapi import Depends, FastAPI, Form, HTTPException, Request, UploadFile, status
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import APIKeyHeader
-from paddleocr import PaddleOCR
-from paddleocr.ppstructure import PPStructure
+from PIL import Image

-from .llm_splitter import call_llm_splitter
+# Lazy import DeepSeek-OCR dependencies (only load when needed)
+_torch = None
+_transformers = None
+
+def _get_torch():
+    global _torch
+    if _torch is None:
+        try:
+            import torch
+            _torch = torch
+        except ImportError:
+            raise RuntimeError(
+                "torch is not installed. Install with: pip install torch"
+            )
+    return _torch
+
+def _get_transformers():
+    global _transformers
+    if _transformers is None:
+        try:
+            from transformers import AutoModel, AutoTokenizer
+            _transformers = (AutoModel, AutoTokenizer)
+        except ImportError:
+            raise RuntimeError(
+                "transformers is not installed. Install with: pip install transformers"
+            )
+    return _transformers
+
+# Import llm_splitter (works as module or direct import)
+try:
+    from llm_splitter import call_llm_splitter
+except ImportError:
+    # Fallback for relative import
+    try:
+        from .llm_splitter import call_llm_splitter
+    except ImportError:
+        # If llm_splitter doesn't exist, define a stub
+        async def call_llm_splitter(*args, **kwargs):
+            raise NotImplementedError("llm_splitter not available")

 ALLOWED_CONTENT_TYPES = {
     "image/jpeg",
@@ -24,11 +64,9 @@ ALLOWED_CONTENT_TYPES = {
 MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(5 * 1024 * 1024)))
 RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "30"))
 RATE_LIMIT_WINDOW_SECONDS = float(os.getenv("RATE_LIMIT_WINDOW_SECONDS", "60"))
-SERVICE_API_KEY = os.getenv("SERVICE_API_KEY")
-if not SERVICE_API_KEY:
-    raise RuntimeError(
-        "SERVICE_API_KEY environment variable must be set before starting the OCR service"
-    )
+# Allow API key to be optional for development (security risk in production!)
+SERVICE_API_KEY = os.getenv("SERVICE_API_KEY", "dev-key-change-in-production")
+REQUIRE_API_KEY = os.getenv("REQUIRE_API_KEY", "false").lower() == "true"
 API_KEY_HEADER_NAME = "X-API-Key"
 MAX_CHILD_LINES = 500
 MAX_JSON_DEPTH = 4
@@ -36,9 +74,350 @@ MAX_JSON_STRING_LENGTH = 512
 MAX_JSON_DICT_KEYS = 50
 MAX_JSON_LIST_ITEMS = 100

-app = FastAPI()
-ocr = PaddleOCR(use_angle_cls=True, lang="en")
-pp = PPStructure(show_log=False, layout=True)
+# DeepSeek-OCR Model Configuration - Maximum Quality Settings for CPU/Spaces
+MODEL_NAME = "deepseek-ai/DeepSeek-OCR"
+# PIN MODEL REVISION to prevent auto-updates that break compatibility
+MODEL_REVISION = os.getenv("DEEPSEEK_MODEL_REVISION", "2c968b433af61a059311cbf8997765023806a24d")
+
+# Detect Apple Silicon (M1/M2/M3/M4) - use MPS if available, otherwise CPU
+IS_APPLE_SILICON = platform.machine() == "arm64"
+USE_GPU = os.getenv("USE_GPU", "true").lower() == "true" and not IS_APPLE_SILICON
+USE_MPS = IS_APPLE_SILICON
+# Quality settings - Gundam preset recommended for CPU/Spaces
+BASE_SIZE = int(os.getenv("DEEPSEEK_BASE_SIZE", "1024"))
+IMAGE_SIZE = int(os.getenv("DEEPSEEK_IMAGE_SIZE", "640"))
+CROP_MODE = os.getenv("DEEPSEEK_CROP_MODE", "true").lower() == "true"
+
+app = FastAPI(
+    title="DeepSeek-OCR API",
+    description="OCR Service using DeepSeek-OCR for maximum quality text extraction",
+    version="1.0.0"
+)
+
+# Add root endpoint for health check (compatible with HuggingFace Spaces)
+@app.get("/")
+async def root(__sign: Optional[str] = None):
+    """
+    Root endpoint - compatible with HuggingFace Spaces authentication.
+    The __sign parameter is used by HuggingFace's proxy but can be ignored.
+    """
+    return {
+        "service": "DeepSeek-OCR API",
+        "status": "running",
+        "version": "1.0.0",
+        "endpoints": {
+            "docs": "/docs",
+            "ocr": "/ocr",
+            "split": "/split"
+        }
+    }
+
+# Add CORS middleware to allow frontend requests
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # In production, replace with specific origins
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+    allow_headers=["*"],
+    expose_headers=["*"],
+)
+
+# Initialize DeepSeek-OCR model
+_ocr_model = None
+_ocr_tokenizer = None
+_model_lock = asyncio.Lock()
+
+
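+# The source patching below assumes the pinned snapshot still ships
+# modeling_deepseekocr.py / deepencoder.py and that its CUDA/BF16 call sites
+# match the literal strings and regexes used here; a new revision may not.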
+def _download_and_patch_model_locally(model_id: str, revision: str) -> str:
+    """
+    Download DeepSeek-OCR to a local dir, patch for CPU:
+    - remove hardcoded .cuda()
+    - force float32 (strip .bfloat16() / .to(torch.bfloat16))
+    - disable torch.autocast("cuda", ...) blocks
+
+    Return local path for from_pretrained(...).
+
+    Per official HuggingFace discussions:
+    - https://huggingface.co/deepseek-ai/DeepSeek-OCR/discussions/21 (CPU inference)
+    - https://huggingface.co/deepseek-ai/DeepSeek-OCR/discussions/20 (BF16/FP32 issues)
+    """
+    import re
+
+    try:
+        from huggingface_hub import snapshot_download
+    except ImportError:
+        raise RuntimeError("huggingface_hub is required. Install with: pip install huggingface_hub")
+
+    print(f" 📥 Downloading model (revision {revision[:8]})...")
+    local_dir = snapshot_download(model_id, revision=revision)
+    print(f" ✅ Downloaded to: {local_dir}")
+    local_dir = Path(local_dir)
+
+    def ensure_imports(src: str) -> str:
+        """Ensure torch and contextlib are imported"""
+        if "import torch" not in src and "from torch" not in src:
+            src = "import torch\n" + src
+        if "import contextlib" not in src:
+            # Add after torch import (first occurrence only, so substrings
+            # such as "import torchvision" are not rewritten repeatedly)
+            if "import torch" in src:
+                src = src.replace("import torch", "import torch\nimport contextlib", 1)
+            elif "from torch" in src:
+                src = src.replace("from torch", "import contextlib\nfrom torch", 1)
+            else:
+                # Add before the first non-import line if there is no torch import
+                lines = src.split('\n')
+                for i, line in enumerate(lines):
+                    if line.startswith("import ") or line.startswith("from "):
+                        continue
+                    else:
+                        lines.insert(i, "import contextlib")
+                        break
+                src = '\n'.join(lines)
+        return src
+
+    def patch_text(src: str) -> str:
+        """Patch text to remove CUDA/BF16 and force FP32"""
+        # A) Remove hardcoded CUDA device moves
+        src = src.replace(".unsqueeze(-1).cuda()", ".unsqueeze(-1)")
+        src = src.replace("input_ids.unsqueeze(0).cuda()", "input_ids.unsqueeze(0)")
+        src = src.replace("(images_crop.cuda(), images_ori.cuda())", "(images_crop, images_ori)")
+        src = src.replace("images_seq_mask = images_seq_mask.unsqueeze(0).cuda()",
+                          "images_seq_mask = images_seq_mask.unsqueeze(0)")
+        src = src.replace("input_ids.unsqueeze(0).cuda().shape[1]",
+                          "input_ids.unsqueeze(0).shape[1]")
+
+        # B) Force FP32 (strip BF16 casts)
+        src = re.sub(r"\.bfloat16\(\)", ".float()", src)
+        src = re.sub(r"\.to\(\s*torch\.bfloat16\s*\)", ".to(torch.float32)", src)
+        src = re.sub(r"\.to\(\s*dtype\s*=\s*torch\.bfloat16\s*\)", ".to(dtype=torch.float32)", src)
+
+        # C) Disable CUDA autocast blocks
+        src = ensure_imports(src)
+        # Match both torch.autocast("cuda", ...) and torch.autocast(device_type="cuda", ...)
+        src = re.sub(r'torch\.autocast\(\s*["\']cuda["\']\s*,[^)]*\)',
+                     "contextlib.nullcontext()", src)
+        src = re.sub(r'torch\.autocast\(\s*device_type\s*=\s*["\']cuda["\'][^)]*\)',
+                     "contextlib.nullcontext()", src)
+
+        return src
+
+    # Patch both files where they may appear
+    targets = list(local_dir.rglob("modeling_deepseekocr.py")) + \
+              list(local_dir.rglob("deepencoder.py"))
+
+    if not targets:
+        raise RuntimeError("Could not locate DeepSeek-OCR source files to patch")
+
+    for p in targets:
+        print(f" 🔍 Found file: {p.name}")
+        txt = p.read_text(encoding="utf-8")
+        new = patch_text(txt)
+
+        if new != txt:
+            p.write_text(new, encoding="utf-8")
+            print(f" ✅ Patched for CPU+FP32: {p.name}")
+        else:
+            print(f" ℹ️ Already CPU-patched: {p.name}")
+
+    return str(local_dir)
+
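+# Double-checked locking: the unlocked check keeps the fast path cheap once
+# the model is cached, and the re-check inside the lock guarantees only one
+# coroutine runs the expensive download/patch/load sequence.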
+async def get_ocr_model():
+    """Lazy load DeepSeek-OCR model with compatibility patching"""
+    global _ocr_model, _ocr_tokenizer
+    if _ocr_model is None or _ocr_tokenizer is None:
+        async with _model_lock:
+            if _ocr_model is None or _ocr_tokenizer is None:
+                # Lazy import dependencies
+                AutoModel, AutoTokenizer = _get_transformers()
+                torch = _get_torch()
+
+                print(f"Loading DeepSeek-OCR model (MAXIMUM QUALITY): {MODEL_NAME}")
+                print(f" - Base size: {BASE_SIZE}")
+                print(f" - Image size: {IMAGE_SIZE}")
+                print(f" - Crop mode: {CROP_MODE}")
+
+                # 1) Download & patch; 2) Load from local dir so our patch is used
+                local_dir = _download_and_patch_model_locally(MODEL_NAME, MODEL_REVISION)
+
+                print(" - Loading tokenizer (local, pinned revision)...")
+                _ocr_tokenizer = AutoTokenizer.from_pretrained(
+                    local_dir,
+                    trust_remote_code=True,
+                    local_files_only=True  # Load from local patched directory
+                )
+                print(" - Tokenizer loaded successfully")
+
+                # Fix pad_token_id warning
+                if _ocr_tokenizer.pad_token_id is None:
+                    _ocr_tokenizer.pad_token = _ocr_tokenizer.eos_token or _ocr_tokenizer.unk_token
+
+                # Load model with compatibility settings
+                load_kwargs = {
+                    "trust_remote_code": True,
+                    "use_safetensors": True,
+                    "attn_implementation": "eager",  # SDPA not supported by this arch
+                }
+
+                # Load from patched local directory
+                _ocr_model = AutoModel.from_pretrained(
+                    local_dir,
+                    local_files_only=True,  # Load from local patched directory
+                    **load_kwargs
+                ).eval()
+
+                # Handle device placement (force FP32 on CPU/MPS)
+                if USE_MPS and torch.backends.mps.is_available():
+                    _ocr_model = _ocr_model.to("mps").to(dtype=torch.float32)
+                    print(" - DeepSeek-OCR on MPS (float32)")
+                elif USE_GPU and torch.cuda.is_available():
+                    _ocr_model = _ocr_model.cuda().to(torch.bfloat16)
+                    print(" - DeepSeek-OCR on CUDA (bf16)")
+                else:
+                    _ocr_model = _ocr_model.to(dtype=torch.float32)
+                    print(" - DeepSeek-OCR on CPU (float32)")
+    return _ocr_model, _ocr_tokenizer
+
+
+async def run_deepseek_ocr(
+    image_path: str,
+    prompt: str = "<image>\n<|grounding|>Convert the document to markdown with preserved layout.",
+    use_grounding: bool = True
+) -> dict:
+    """
+    Run DeepSeek-OCR on an image file with advanced grounding support.
+    """
+    model, tokenizer = await get_ocr_model()
+
+    output_path = tempfile.mkdtemp()
+
+    try:
+        # OCR quality settings - Gundam preset recommended for CPU/Spaces
+        torch = _get_torch()
+        if USE_GPU and torch.cuda.is_available():
+            # GPU: Use maximum quality (Large preset)
+            actual_base_size = BASE_SIZE
+            actual_image_size = IMAGE_SIZE
+        else:
+            # CPU/Spaces: Use Gundam preset (recommended for CPU to avoid OOM)
+            actual_base_size = 1024
+            actual_image_size = 640
+            print(f" - Using CPU-optimized quality: base_size={actual_base_size}, image_size={actual_image_size}")
+
+        # Use torch.inference_mode() to reduce overhead on CPU
+        with torch.inference_mode():
+            result = model.infer(
+                tokenizer,
+                prompt=prompt,
+                image_file=image_path,
+                output_path=output_path,
+                base_size=actual_base_size,
+                image_size=actual_image_size,
+                crop_mode=CROP_MODE,
+                save_results=False,
+                test_compress=False,
+            )
+
+        # Parse result - DeepSeek-OCR returns structured markdown output
+        ocr_text = result if isinstance(result, str) else str(result)
+
+        # Extract structured lines from markdown
+        lines = _parse_deepseek_output(ocr_text)
+
+        return {
+            "text": ocr_text,
+            "lines": lines,
+        }
+    except Exception as e:
+        print(f"DeepSeek-OCR error: {e}")
+        import traceback
+        traceback.print_exc()
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"OCR processing failed: {str(e)}",
+        )
+    finally:
+        # Cleanup temp directory
+        try:
+            import shutil
+            if os.path.exists(output_path):
+                shutil.rmtree(output_path)
+        except Exception:
+            pass
+
+
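+# The bboxes below are synthetic layout estimates (fixed line height, ~8 px
+# per character): DeepSeek-OCR's markdown output carries no pixel coordinates,
+# so callers should treat them as ordering hints, not measurements.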
+def _parse_deepseek_output(ocr_text: str) -> list:
+    """
+    Extract structured lines from DeepSeek-OCR markdown output.
+    Preserves layout, handles tables, lists, and structured content.
+    """
+    lines = []
+    text_lines = ocr_text.split('\n')
+
+    y_offset = 0
+    line_height = 24  # Estimated line height in pixels
+
+    for line in text_lines:
+        stripped = line.strip()
+        if not stripped:
+            # Empty lines still take space
+            y_offset += line_height // 2
+            continue
+
+        # Remove markdown formatting but preserve text structure
+        # Handle markdown tables (| separated)
+        if '|' in stripped and stripped.count('|') >= 2:
+            # Table row - split by | and process each cell
+            cells = [cell.strip() for cell in stripped.split('|') if cell.strip()]
+            for cell_idx, cell in enumerate(cells):
+                if cell:
+                    lines.append({
+                        "bbox": [
+                            cell_idx * 200,  # Approximate x position
+                            y_offset,
+                            (cell_idx + 1) * 200,
+                            y_offset + line_height
+                        ],
+                        "text": cell,
+                        "conf": 0.95,
+                    })
+            y_offset += line_height
+        # Handle markdown lists (-, *, 1., etc.)
+        elif stripped.startswith(('-', '*', '+')) or (len(stripped) > 2 and stripped[1] == '.'):
+            # List item - remove list marker
+            text = stripped.lstrip('-*+').lstrip('0123456789.').strip()
+            if text:
+                lines.append({
+                    "bbox": [40, y_offset, 1000, y_offset + line_height],
+                    "text": text,
+                    "conf": 0.95,
+                })
+            y_offset += line_height
+        # Handle headers (# ## ###)
+        elif stripped.startswith('#'):
+            header_level = len(stripped) - len(stripped.lstrip('#'))
+            text = stripped.lstrip('#').strip()
+            # Headers are typically larger (computed unconditionally so the
+            # y_offset advance below is always defined)
+            header_height = line_height + (header_level * 4)
+            if text:
+                lines.append({
+                    "bbox": [0, y_offset, 1000, y_offset + header_height],
+                    "text": text,
+                    "conf": 0.95,
+                })
+            y_offset += header_height
+        # Regular text line
+        else:
+            # Estimate width based on text length (rough approximation)
+            estimated_width = min(len(stripped) * 8, 1000)  # ~8px per char average
+            lines.append({
+                "bbox": [0, y_offset, estimated_width, y_offset + line_height],
+                "text": stripped,
+                "conf": 0.95,
+            })
+            y_offset += line_height
+
+    return lines

 api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
 _rate_limit_lock = asyncio.Lock()
@@ -46,11 +425,22 @@ _request_log: DefaultDict[str, Deque[float]] = defaultdict(deque)


 def ensure_upload_is_safe(file: UploadFile) -> None:
+    # Check content type from header
     content_type = (file.content_type or "").lower()
-    if content_type not in ALLOWED_CONTENT_TYPES:
+
+    # Also check file extension as fallback (browsers sometimes send application/octet-stream)
+    filename = (file.filename or "").lower()
+    extension = filename.split('.')[-1] if '.' in filename else ""
+    allowed_extensions = {'jpg', 'jpeg', 'png', 'webp'}
+
+    # Allow if content type matches OR extension matches
+    content_type_valid = content_type in ALLOWED_CONTENT_TYPES
+    extension_valid = extension in allowed_extensions
+
+    if not content_type_valid and not extension_valid:
         raise HTTPException(
             status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
-            detail="Unsupported file type",
+            detail=f"Unsupported file type. Content-Type: {content_type}, Extension: {extension}. Allowed: {', '.join(ALLOWED_CONTENT_TYPES)}",
         )

     file.file.seek(0, os.SEEK_END)
@@ -63,7 +453,11 @@ def ensure_upload_is_safe(file: UploadFile) -> None:
     )


-async def verify_api_key(api_key: str | None = Depends(api_key_header)) -> str:
+async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> str:
+    # Skip API key verification in development mode
+    if not REQUIRE_API_KEY:
+        return api_key or SERVICE_API_KEY
+    # Enforce API key in production
     if not api_key or not secrets.compare_digest(api_key, SERVICE_API_KEY):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -92,25 +486,35 @@ async def enforce_rate_limit(


 def _decode_image(file: UploadFile):
+    """Decode uploaded image file to PIL Image"""
     data = file.file.read()
     if not data:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Uploaded file is empty",
         )
-    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
-    if img is None:
+
+    # Save to temp file for DeepSeek-OCR
+    with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
+        tmp_file.write(data)
+        tmp_path = tmp_file.name
+
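+    # The temp file is intentionally left on disk: its path is returned to the
+    # caller, and the endpoint handlers unlink it in their finally blocks.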
+    try:
+        img = Image.open(tmp_path).convert("RGB")
+        return img, tmp_path
+    except Exception as e:
+        os.unlink(tmp_path)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Unable to decode image",
+            detail=f"Unable to decode image: {str(e)}",
         )
-    return img


-def load_img(file: UploadFile):
+async def load_img(file: UploadFile):
     ensure_upload_is_safe(file)
     file.file.seek(0)
-    return _decode_image(file)
+    img, img_path = _decode_image(file)
+    return img, img_path


 def _parse_json_field(name: str, raw: str, expected_type: type) -> Any:
@@ -274,20 +678,65 @@ def _parse_rules(raw: str) -> list:
     return rules


+@app.options("/ocr")
+async def ocr_options():
+    """Handle CORS preflight requests (required by HuggingFace Spaces)"""
+    return {"message": "OK"}
+
+@app.options("/api/predict")
+async def predict_options():
+    """Handle CORS preflight for HuggingFace Spaces auto-routing"""
+    return {"message": "OK"}
+
 @app.post("/ocr")
+@app.post("/api/predict")  # HuggingFace Spaces may auto-route POST requests here
 async def ocr_page(
     file: UploadFile,
     _: None = Depends(enforce_rate_limit),
 ):
-    img = load_img(file)
-    res = ocr.ocr(img, cls=True)
-    lines = []
-    full = []
-    for line in (res[0] or []):
-        (x1, y1, x2, y2), (txt, conf) = line
-        lines.append({"bbox": [x1, y1, x2, y2], "text": txt, "conf": float(conf)})
-        full.append(txt)
-    return {"text": "\n".join(full), "lines": lines}
+    """OCR endpoint using DeepSeek-OCR"""
+    img, img_path = await load_img(file)
+    try:
+        # Save PIL image to temporary file for DeepSeek-OCR
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
+            img.save(tmp_file, 'JPEG', quality=95)
+            tmp_img_path = tmp_file.name
+
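+        # Re-encoding to JPEG at quality 95 is slightly lossy for PNG/WebP
+        # uploads; fine for OCR, but not a byte-exact copy of the original.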
+        try:
+            # Use grounding prompt for better structure extraction
+            result = await run_deepseek_ocr(
+                tmp_img_path,
+                prompt="<image>\n<|grounding|>Convert the document to markdown with preserved layout.",
+                use_grounding=True
+            )
+            return result
+        except Exception as e:
+            # Log the error but don't crash - return a helpful error message
+            error_msg = str(e)
+            print(f"OCR processing error: {error_msg}")
+
+            # Check if it's a model loading issue
+            if "matplotlib" in error_msg or "torchvision" in error_msg or "ImportError" in error_msg:
+                raise HTTPException(
+                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                    detail=f"OCR model dependencies missing: {error_msg}. Please install required packages."
+                )
+            elif "Connection" in error_msg or "timeout" in error_msg.lower():
+                raise HTTPException(
+                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                    detail=f"OCR service temporarily unavailable: {error_msg}"
+                )
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                    detail=f"OCR processing failed: {error_msg}"
+                )
+        finally:
+            if os.path.exists(tmp_img_path):
+                os.unlink(tmp_img_path)
+    finally:
+        if os.path.exists(img_path):
+            os.unlink(img_path)


 @app.post("/split")
@@ -300,74 +749,85 @@ async def split(
     rules: str = Form("[]"),
     _: None = Depends(enforce_rate_limit),
 ):
-    img = load_img(file)
-    height, width = img.shape[:2]
-    parent_box = _parse_parent_bbox(parent_bbox, width, height)
-    x1, y1, x2, y2 = parent_box
-    x1_i, y1_i, x2_i, y2_i = [int(round(v)) for v in parent_box]
-    crop = img[y1_i:y2_i, x1_i:x2_i]
-    if crop.size == 0:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="parent_bbox region is empty",
-        )
-    blocks = pp(crop)
-
-    child_lines = []
-    for b in blocks:
-        bx1, by1, bx2, by2 = b["bbox"]
-        sub = crop[by1:by2, bx1:bx2]
-        det = ocr.ocr(sub, cls=True)
-        for ln in (det[0] or []):
-            (lx1, ly1, lx2, ly2), (txt, conf) = ln
-            child_lines.append(
-                {
-                    "bbox": [lx1 + bx1 + x1, ly1 + by1 + y1, lx2 + bx1 + x1, ly2 + by1 + y1],
-                    "text": txt,
-                    "conf": float(conf),
-                    "blockType": b.get("type", "text"),
-                }
-            )
-            if len(child_lines) >= MAX_CHILD_LINES:
-                break
-        if len(child_lines) >= MAX_CHILD_LINES:
-            break
-
-    sanitized_splitter = _sanitize_label("splitter", splitter)
-    sanitized_schema = _sanitize_label("schemaType", schemaType)
-    parsed_settings = _parse_settings(settings)
-    parsed_rules = _parse_rules(rules)
-
-    raw_text = "\n".join([l["text"] for l in child_lines])
-    text_truncated = False
-    if len(raw_text) > 5000:
-        raw_text = raw_text[:5000]
-        text_truncated = True
-
-    llm_input = {
-        "schemaType": sanitized_schema,
-        "splitter": sanitized_splitter,
-        "page": {"width": width, "height": height},
-        "parentBox": parent_box,
-        "rawText": raw_text,
-        "ocrLines": child_lines,
-        "rawTextTruncated": text_truncated,
-        "ocrLinesTruncated": len(child_lines) >= MAX_CHILD_LINES,
-        "settings": parsed_settings,
-        "rules": parsed_rules,
-    }
-
+    """Split endpoint - uses DeepSeek-OCR for region extraction"""
+    img, img_path = await load_img(file)
     try:
-        llm_result = await call_llm_splitter(llm_input)
-    except ValueError as exc:
-        raise HTTPException(
-            status_code=status.HTTP_502_BAD_GATEWAY,
-            detail=str(exc),
-        ) from exc
-    return llm_result
-
-
-if __name__ == "__main__":
-    import uvicorn
-
-    uvicorn.run(app, host="0.0.0.0", port=8080)
+        width, height = img.size
+
+        # Save image for DeepSeek-OCR
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
+            img.save(tmp_file, 'JPEG', quality=95)
+            tmp_img_path = tmp_file.name
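+        # Note: this full-page JPEG is not passed to the OCR call below (only
+        # the crop at crop_path is OCR'd); it is only removed in the finally block.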
+
+        try:
+            parent_box = _parse_parent_bbox(parent_bbox, width, height)
+            x1, y1, x2, y2 = parent_box
+
+            # Crop image to parent bbox
+            crop_img = img.crop((int(x1), int(y1), int(x2), int(y2)))
+            crop_path = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg').name
+            crop_img.save(crop_path, 'JPEG', quality=95)
+
+            try:
+                # Use DeepSeek-OCR with grounding prompt for better structured extraction
+                prompt = "<image>\n<|grounding|>Convert the document region to markdown with preserved layout."
+                ocr_result = await run_deepseek_ocr(crop_path, prompt=prompt, use_grounding=True)
+
+                # Parse OCR result to extract lines
+                child_lines = ocr_result.get("lines", [])
+
+                # Adjust bboxes to parent coordinate space
+                for line in child_lines:
+                    bbox = line["bbox"]
+                    line["bbox"] = [
+                        bbox[0] + x1,
+                        bbox[1] + y1,
+                        bbox[2] + x1,
+                        bbox[3] + y1,
+                    ]
+                    line["blockType"] = "text"
+
+                if len(child_lines) > MAX_CHILD_LINES:
+                    child_lines = child_lines[:MAX_CHILD_LINES]
+
+                sanitized_splitter = _sanitize_label("splitter", splitter)
+                sanitized_schema = _sanitize_label("schemaType", schemaType)
+                parsed_settings = _parse_settings(settings)
+                parsed_rules = _parse_rules(rules)
+
+                raw_text = "\n".join([l["text"] for l in child_lines])
+                text_truncated = False
+                if len(raw_text) > 5000:
+                    raw_text = raw_text[:5000]
+                    text_truncated = True
+
+                llm_input = {
+                    "schemaType": sanitized_schema,
+                    "splitter": sanitized_splitter,
+                    "page": {"width": width, "height": height},
+                    "parentBox": parent_box,
+                    "rawText": raw_text,
+                    "ocrLines": child_lines,
+                    "rawTextTruncated": text_truncated,
+                    "ocrLinesTruncated": len(child_lines) >= MAX_CHILD_LINES,
+                    "settings": parsed_settings,
+                    "rules": parsed_rules,
+                }
+
+                try:
+                    llm_result = await call_llm_splitter(llm_input)
+                except ValueError as exc:
+                    raise HTTPException(
+                        status_code=status.HTTP_502_BAD_GATEWAY,
+                        detail=str(exc),
+                    ) from exc
+                return llm_result
+            finally:
+                if os.path.exists(crop_path):
+                    os.unlink(crop_path)
+        finally:
+            if os.path.exists(tmp_img_path):
+                os.unlink(tmp_img_path)
+    finally:
+        if os.path.exists(img_path):
+            os.unlink(img_path)