make789 committed on
Commit 0e7bad9 · verified · 1 Parent(s): 0d047c5

Upload ocr_service.py

Files changed (1)
  1. ocr_service.py +96 -552
ocr_service.py CHANGED
@@ -2,59 +2,19 @@ import asyncio
 import json
 import math
 import os
-import platform
 import secrets
-import tempfile
 from collections import defaultdict, deque
 from time import monotonic
-from typing import Any, Deque, DefaultDict, Optional
-from pathlib import Path
+from typing import Any, Deque, DefaultDict
 
+import cv2
 import numpy as np
 from fastapi import Depends, FastAPI, Form, HTTPException, Request, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import APIKeyHeader
-from PIL import Image
+from paddleocr import PaddleOCR
+from paddleocr.ppstructure import PPStructure
 
-# Lazy import DeepSeek-OCR dependencies (only load when needed)
-_torch = None
-_transformers = None
-
-def _get_torch():
-    global _torch
-    if _torch is None:
-        try:
-            import torch
-            _torch = torch
-        except ImportError:
-            raise RuntimeError(
-                "torch is not installed. Install with: pip install torch"
-            )
-    return _torch
-
-def _get_transformers():
-    global _transformers
-    if _transformers is None:
-        try:
-            from transformers import AutoModel, AutoTokenizer
-            _transformers = (AutoModel, AutoTokenizer)
-        except ImportError:
-            raise RuntimeError(
-                "transformers is not installed. Install with: pip install transformers"
-            )
-    return _transformers
-
-# Import llm_splitter (works as module or direct import)
-try:
-    from llm_splitter import call_llm_splitter
-except ImportError:
-    # Fallback for relative import
-    try:
-        from .llm_splitter import call_llm_splitter
-    except ImportError:
-        # If llm_splitter doesn't exist, define a stub
-        async def call_llm_splitter(*args, **kwargs):
-            raise NotImplementedError("llm_splitter not available")
+from .llm_splitter import call_llm_splitter
 
 ALLOWED_CONTENT_TYPES = {
     "image/jpeg",
@@ -64,9 +24,11 @@ ALLOWED_CONTENT_TYPES = {
 MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(5 * 1024 * 1024)))
 RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "30"))
 RATE_LIMIT_WINDOW_SECONDS = float(os.getenv("RATE_LIMIT_WINDOW_SECONDS", "60"))
-# Allow API key to be optional for development (security risk in production!)
-SERVICE_API_KEY = os.getenv("SERVICE_API_KEY", "dev-key-change-in-production")
-REQUIRE_API_KEY = os.getenv("REQUIRE_API_KEY", "false").lower() == "true"
+SERVICE_API_KEY = os.getenv("SERVICE_API_KEY")
+if not SERVICE_API_KEY:
+    raise RuntimeError(
+        "SERVICE_API_KEY environment variable must be set before starting the OCR service"
+    )
 API_KEY_HEADER_NAME = "X-API-Key"
 MAX_CHILD_LINES = 500
 MAX_JSON_DEPTH = 4
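
This hunk makes the key mandatory: the service now refuses to start without SERVICE_API_KEY instead of falling back to a development default, so every caller has to send the real key in the X-API-Key header. A minimal client sketch (the URL, key value, and sample.jpg are illustrative assumptions, not part of the commit):

    import requests

    headers = {"X-API-Key": "my-secret-key"}  # must match the server's SERVICE_API_KEY
    with open("sample.jpg", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/ocr",
            headers=headers,
            files={"file": ("sample.jpg", f, "image/jpeg")},
        )
    resp.raise_for_status()
    print(resp.json()["text"])
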
@@ -74,340 +36,9 @@ MAX_JSON_STRING_LENGTH = 512
 MAX_JSON_DICT_KEYS = 50
 MAX_JSON_LIST_ITEMS = 100
 
-# DeepSeek-OCR Model Configuration - Maximum Quality Settings for M4 Mac (Apple Silicon)
-MODEL_NAME = "deepseek-ai/DeepSeek-OCR"
-# PIN MODEL REVISION to prevent auto-updates that break compatibility
-# Use a commit hash from https://huggingface.co/deepseek-ai/DeepSeek-OCR/tree/main
-# This prevents "A new version of ... was downloaded" warnings and keeps code stable
-MODEL_REVISION = os.getenv("DEEPSEEK_MODEL_REVISION", "2c968b433af61a059311cbf8997765023806a24d")  # Latest stable commit
-
-# Detect Apple Silicon (M1/M2/M3/M4) - use MPS if available, otherwise CPU
-IS_APPLE_SILICON = platform.machine() == "arm64"
-USE_GPU = os.getenv("USE_GPU", "true").lower() == "true" and not IS_APPLE_SILICON  # M4 uses MPS, not CUDA
-USE_MPS = IS_APPLE_SILICON  # Use Metal Performance Shaders on Apple Silicon
-# Maximum quality settings (larger = better, slower = more accurate)
-BASE_SIZE = int(os.getenv("DEEPSEEK_BASE_SIZE", "1280"))  # Maximum quality: 1280 (not light!)
-IMAGE_SIZE = int(os.getenv("DEEPSEEK_IMAGE_SIZE", "1280"))  # Maximum quality: 1280 (not light!)
-CROP_MODE = os.getenv("DEEPSEEK_CROP_MODE", "true").lower() == "true"  # True for best accuracy
-
-app = FastAPI(
-    title="DeepSeek-OCR API",
-    description="OCR Service using DeepSeek-OCR for maximum quality text extraction",
-    version="1.0.0"
-)
-
-# Add root endpoint for health check
-@app.get("/")
-async def root(__sign: Optional[str] = None):
-    """
-    Root endpoint - compatible with HuggingFace Spaces authentication.
-    The __sign parameter is used by HuggingFace's proxy but can be ignored.
-    """
-    return {
-        "service": "DeepSeek-OCR API",
-        "status": "running",
-        "version": "1.0.0",
-        "endpoints": {
-            "docs": "/docs",
-            "ocr": "/ocr",
-            "split": "/split"
-        }
-    }
-
-# Add CORS middleware to allow frontend requests
-# Configured for HuggingFace Spaces which requires explicit CORS handling
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # In production, replace with specific origins
-    allow_credentials=True,
-    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
-    allow_headers=["*"],
-    expose_headers=["*"],
-)
-
-# Initialize DeepSeek-OCR model
-_ocr_model = None
-_ocr_tokenizer = None
-_model_lock = asyncio.Lock()
-
-
-def _download_and_patch_model_locally(model_id: str, revision: str) -> str:
-    """
-    Download DeepSeek-OCR to a local dir, patch out hardcoded .cuda() in infer(),
-    and return the local path for from_pretrained(...).
-
-    This ensures we control the exact file that gets imported (patched version).
-
-    Per official HuggingFace discussion: https://huggingface.co/deepseek-ai/DeepSeek-OCR/discussions/21
-    """
-    from pathlib import Path
-    import re
-
-    try:
-        from huggingface_hub import snapshot_download
-    except ImportError:
-        raise RuntimeError("huggingface_hub is required. Install with: pip install huggingface_hub")
-
-    print(f" 📥 Downloading model (revision {revision[:8]})...")
-    local_dir = snapshot_download(model_id, revision=revision)
-    print(f" ✅ Downloaded to: {local_dir}")
-
-    # Find modeling_deepseekocr.py (may be in subdirectories with trust_remote_code)
-    target = Path(local_dir) / "modeling_deepseekocr.py"
-    if not target.exists():
-        # trust_remote_code sometimes stores under modules; fall back to a glob
-        hits = list(Path(local_dir).rglob("modeling_deepseekocr.py"))
-        target = hits[0] if hits else None
-
-    if not target or not target.exists():
-        raise RuntimeError("Could not locate modeling_deepseekocr.py to patch")
-
-    print(f" 🔍 Found model file: {target}")
-
-    # Read file
-    src = target.read_text(encoding='utf-8')
-
-    # Check if already patched
-    if ".unsqueeze(-1).cuda()" not in src and "input_ids.unsqueeze(0).cuda()" not in src:
-        print(f" ✅ Model already patched for CPU")
-        return local_dir
-
-    # CPU patches from the HF discussion (remove hardcoded .cuda() in infer)
-    # https://huggingface.co/deepseek-ai/DeepSeek-OCR/discussions/21
-    repl = [
-        (r"\.unsqueeze\(-1\)\.cuda\(\)", ".unsqueeze(-1)"),
-        (r"input_ids\.unsqueeze\(0\)\.cuda\(\)", "input_ids.unsqueeze(0)"),
-        (r"\(images_crop\.cuda\(\), images_ori\.cuda\(\)\)", "(images_crop, images_ori)"),
-        (r"images_seq_mask = images_seq_mask\.unsqueeze\(0\)\.cuda\(\)",
-         "images_seq_mask = images_seq_mask.unsqueeze(0)"),
-        (r"input_ids\.unsqueeze\(0\)\.cuda\(\)\.shape\[1\]",
-         "input_ids.unsqueeze(0).shape[1]"),
-    ]
-
-    original_src = src
-    for pat, rep in repl:
-        src = re.sub(pat, rep, src)
-
-    # Write patched file
-    if src != original_src:
-        target.write_text(src, encoding='utf-8')
-        print(f" ✅ Patched for CPU: {target}")
-    else:
-        print(f" ⚠️ No .cuda() calls found to patch")
-
-    return local_dir
-
-async def get_ocr_model():
-    """Lazy load DeepSeek-OCR model with compatibility patching"""
-    global _ocr_model, _ocr_tokenizer
-    if _ocr_model is None or _ocr_tokenizer is None:
-        async with _model_lock:
-            if _ocr_model is None or _ocr_tokenizer is None:
-                # Lazy import dependencies
-                # Note: Patching no longer needed - we pin transformers==4.46.3 and model revision
-                AutoModel, AutoTokenizer = _get_transformers()
-                torch = _get_torch()
-
-                print(f"Loading DeepSeek-OCR model (MAXIMUM QUALITY): {MODEL_NAME}")
-                print(f" - Base size: {BASE_SIZE}")
-                print(f" - Image size: {IMAGE_SIZE}")
-                print(f" - Crop mode: {CROP_MODE}")
-
-                # 1) Download & patch; 2) Load from local dir so our patch is used
-                local_dir = _download_and_patch_model_locally(MODEL_NAME, MODEL_REVISION)
-
-                print(" - Loading tokenizer (local, pinned revision)...")
-                _ocr_tokenizer = AutoTokenizer.from_pretrained(
-                    local_dir,
-                    trust_remote_code=True,
-                    local_files_only=True  # Load from local patched directory
-                )
-                print(" - Tokenizer loaded successfully")
-
-                # Load model with compatibility settings
-                # Official DeepSeek-OCR usage: https://huggingface.co/deepseek-ai/DeepSeek-OCR
-                # GPU version uses: attn_implementation='flash_attention_2', use_safetensors=True
-                # CPU/Spaces version MUST use: attn_implementation='eager' (SDPA not implemented for this arch)
-                # Fix per: https://github.com/huggingface/transformers/issues/28005
-                load_kwargs = {
-                    "trust_remote_code": True,
-                    "use_safetensors": True,
-                    "attn_implementation": "eager",  # SDPA not implemented for this arch
-                }
-
-                # If CUDA exists you can still cast later; but keep eager attention
-                _ocr_model = AutoModel.from_pretrained(
-                    local_dir,
-                    local_files_only=True,  # Load from local patched directory
-                    **load_kwargs
-                ).eval()
-
-                # Handle device placement (per official DeepSeek-OCR usage)
-                if USE_MPS and torch.backends.mps.is_available():
-                    # Apple Silicon: MPS (Metal Performance Shaders)
-                    _ocr_model = _ocr_model.to("mps")
-                    print(" - DeepSeek-OCR loaded on Apple Silicon GPU (MPS/M4)")
-                elif USE_GPU and torch.cuda.is_available():
-                    # NVIDIA GPU: CUDA with bfloat16 (per official usage)
-                    _ocr_model = _ocr_model.cuda().to(torch.bfloat16)
-                    print(" - DeepSeek-OCR loaded on NVIDIA GPU (CUDA + bfloat16)")
-                else:
-                    # CPU: No device placement needed
-                    print(" - DeepSeek-OCR loaded on CPU")
-    return _ocr_model, _ocr_tokenizer
-
-
-async def run_deepseek_ocr(
-    image_path: str,
-    prompt: str = "<image>\n<|grounding|>Convert the document to markdown with preserved layout.",
-    use_grounding: bool = True
-) -> dict:
-    """
-    Run DeepSeek-OCR on an image file with advanced grounding support.
-
-    Genius enhancement: Uses grounding prompts for better structure extraction
-    and layout preservation, following DeepSeek-OCR best practices.
-    """
-    model, tokenizer = await get_ocr_model()
-
-    output_path = tempfile.mkdtemp()
-
-    try:
-        # OCR quality settings
-        # Official DeepSeek-OCR quality presets (from https://huggingface.co/deepseek-ai/DeepSeek-OCR):
-        #   - Tiny: base_size=512, image_size=512, crop_mode=False
-        #   - Small: base_size=640, image_size=640, crop_mode=False
-        #   - Base: base_size=1024, image_size=1024, crop_mode=False
-        #   - Large: base_size=1280, image_size=1280, crop_mode=False (GPU only)
-        #   - Gundam: base_size=1024, image_size=640, crop_mode=True (recommended for CPU/ZeroGPU)
-        #
-        # Note: On CPU/Spaces, use lower quality to avoid slowdowns/OOM
-        # We adjust based on device availability
-        # Get torch to check for GPU availability
-        torch = _get_torch()
-        if USE_GPU and torch.cuda.is_available():
-            # GPU: Use maximum quality (Large preset)
-            actual_base_size = BASE_SIZE  # 1280
-            actual_image_size = IMAGE_SIZE  # 1280
-        else:
-            # CPU/Spaces: Use Gundam preset (recommended for CPU to avoid OOM)
-            actual_base_size = 1024
-            actual_image_size = 640
-            print(f" - Using CPU-optimized quality: base_size={actual_base_size}, image_size={actual_image_size}")
-
-        result = model.infer(
-            tokenizer,
-            prompt=prompt,
-            image_file=image_path,
-            output_path=output_path,
-            base_size=actual_base_size,
-            image_size=actual_image_size,
-            crop_mode=CROP_MODE,  # True = best accuracy for complex documents
-            save_results=False,  # Don't save intermediate files
-            test_compress=False,  # False = maximum quality, no compression
-        )
-
-        # Parse result - DeepSeek-OCR returns structured markdown output
-        ocr_text = result if isinstance(result, str) else str(result)
-
-        # Genius parsing: Extract structured lines from markdown with better layout awareness
-        lines = _parse_deepseek_output(ocr_text)
-
-        return {
-            "text": ocr_text,
-            "lines": lines,
-        }
-    except Exception as e:
-        print(f"DeepSeek-OCR error: {e}")
-        import traceback
-        traceback.print_exc()
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"OCR processing failed: {str(e)}",
-        )
-    finally:
-        # Cleanup temp directory
-        try:
-            import shutil
-            if os.path.exists(output_path):
-                shutil.rmtree(output_path)
-        except:
-            pass
-
-
-def _parse_deepseek_output(ocr_text: str) -> list:
-    """
-    Genius parser: Extract structured lines from DeepSeek-OCR markdown output.
-    Preserves layout, handles tables, lists, and structured content.
-    """
-    lines = []
-    text_lines = ocr_text.split('\n')
-
-    y_offset = 0
-    line_height = 24  # Estimated line height in pixels
-
-    for line_idx, line in enumerate(text_lines):
-        stripped = line.strip()
-        if not stripped:
-            # Empty lines still take space
-            y_offset += line_height // 2
-            continue
-
-        # Remove markdown formatting but preserve text structure
-        # Handle markdown tables (| separated)
-        if '|' in stripped and stripped.count('|') >= 2:
-            # Table row - split by | and process each cell
-            cells = [cell.strip() for cell in stripped.split('|') if cell.strip()]
-            for cell_idx, cell in enumerate(cells):
-                if cell:
-                    lines.append({
-                        "bbox": [
-                            cell_idx * 200,  # Approximate x position
-                            y_offset,
-                            (cell_idx + 1) * 200,
-                            y_offset + line_height
-                        ],
-                        "text": cell,
-                        "conf": 0.95,
-                    })
-            y_offset += line_height
-        # Handle markdown lists (-, *, 1., etc.)
-        elif stripped.startswith(('-', '*', '+')) or (len(stripped) > 2 and stripped[1] == '.'):
-            # List item - remove list marker
-            text = stripped.lstrip('-*+').lstrip('0123456789.').strip()
-            if text:
-                lines.append({
-                    "bbox": [40, y_offset, 1000, y_offset + line_height],
-                    "text": text,
-                    "conf": 0.95,
-                })
-            y_offset += line_height
-        # Handle headers (# ## ###)
-        elif stripped.startswith('#'):
-            header_level = len(stripped) - len(stripped.lstrip('#'))
-            text = stripped.lstrip('#').strip()
-            if text:
-                # Headers are typically larger
-                header_height = line_height + (header_level * 4)
-                lines.append({
-                    "bbox": [0, y_offset, 1000, y_offset + header_height],
-                    "text": text,
-                    "conf": 0.95,
-                })
-            y_offset += header_height
-        # Regular text line
-        else:
-            # Estimate width based on text length (rough approximation)
-            estimated_width = min(len(stripped) * 8, 1000)  # ~8px per char average
-            lines.append({
-                "bbox": [0, y_offset, estimated_width, y_offset + line_height],
-                "text": stripped,
-                "conf": 0.95,
-            })
-            y_offset += line_height
-
-    return lines
-
+app = FastAPI()
+ocr = PaddleOCR(use_angle_cls=True, lang="en")
+pp = PPStructure(show_log=False, layout=True)
 
 api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
 _rate_limit_lock = asyncio.Lock()
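
Everything DeepSeek-specific above (snapshot download, .cuda() patching, device placement, markdown re-parsing with estimated boxes) collapses into two module-level PaddleOCR objects shared by both endpoints. A sketch of the result shape the handlers below rely on, assuming the PaddleOCR 2.x API this code targets:

    import cv2
    from paddleocr import PaddleOCR

    ocr = PaddleOCR(use_angle_cls=True, lang="en")
    img = cv2.imread("sample.jpg")  # hypothetical test image
    result = ocr.ocr(img, cls=True)
    # One entry per input image; each detected line pairs a four-point
    # quad with a (text, confidence) tuple.
    for quad, (text, confidence) in result[0] or []:
        print(round(confidence, 3), text, quad)

Building the models at import time also moves the load cost from the first request to process startup.
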
@@ -415,22 +46,11 @@ _request_log: DefaultDict[str, Deque[float]] = defaultdict(deque)
 
 
 def ensure_upload_is_safe(file: UploadFile) -> None:
-    # Check content type from header
     content_type = (file.content_type or "").lower()
-
-    # Also check file extension as fallback (browsers sometimes send application/octet-stream)
-    filename = (file.filename or "").lower()
-    extension = filename.split('.')[-1] if '.' in filename else ""
-    allowed_extensions = {'jpg', 'jpeg', 'png', 'webp'}
-
-    # Allow if content type matches OR extension matches
-    content_type_valid = content_type in ALLOWED_CONTENT_TYPES
-    extension_valid = extension in allowed_extensions
-
-    if not content_type_valid and not extension_valid:
+    if content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
-            detail=f"Unsupported file type. Content-Type: {content_type}, Extension: {extension}. Allowed: {', '.join(ALLOWED_CONTENT_TYPES)}",
+            detail="Unsupported file type",
        )
 
     file.file.seek(0, os.SEEK_END)
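
With the extension fallback gone, an upload that arrives as application/octet-stream is rejected even if its filename ends in .jpg, so clients should declare the MIME type explicitly. A sketch (file name, key, and URL are assumptions):

    import requests

    with open("scan.jpg", "rb") as f:
        files = {"file": ("scan.jpg", f, "image/jpeg")}  # explicit MIME type
        r = requests.post(
            "http://localhost:8000/ocr",
            headers={"X-API-Key": "my-secret-key"},
            files=files,
        )
    print(r.status_code)  # 415 if the declared type is not in ALLOWED_CONTENT_TYPES
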
@@ -443,11 +63,7 @@ def ensure_upload_is_safe(file: UploadFile) -> None:
         )
 
 
-async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> str:
-    # Skip API key verification in development mode
-    if not REQUIRE_API_KEY:
-        return api_key or SERVICE_API_KEY
-    # Enforce API key in production
+async def verify_api_key(api_key: str | None = Depends(api_key_header)) -> str:
     if not api_key or not secrets.compare_digest(api_key, SERVICE_API_KEY):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
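
The dev-mode bypass is gone, so every request goes through secrets.compare_digest, which takes time independent of where the inputs first differ and therefore does not leak key prefixes through response timing. An illustration with a stand-in key value:

    import secrets

    SERVICE_API_KEY = "my-secret-key"  # stand-in for the real environment value

    print(secrets.compare_digest("my-secret-key", SERVICE_API_KEY))  # True
    print(secrets.compare_digest("my-secret-kez", SERVICE_API_KEY))  # False
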
@@ -475,36 +91,26 @@ async def enforce_rate_limit(
     window.append(now)
 
 
-def _decode_image(file: UploadFile) -> Image.Image:
-    """Decode uploaded image file to PIL Image"""
+def _decode_image(file: UploadFile):
     data = file.file.read()
     if not data:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Uploaded file is empty",
         )
-
-    # Save to temp file for DeepSeek-OCR
-    with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
-        tmp_file.write(data)
-        tmp_path = tmp_file.name
-
-    try:
-        img = Image.open(tmp_path).convert("RGB")
-        return img, tmp_path
-    except Exception as e:
-        os.unlink(tmp_path)
+    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
+    if img is None:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"Unable to decode image: {str(e)}",
+            detail="Unable to decode image",
         )
+    return img
 
 
-async def load_img(file: UploadFile):
+def load_img(file: UploadFile):
     ensure_upload_is_safe(file)
     file.file.seek(0)
-    img, img_path = _decode_image(file)
-    return img, img_path
+    return _decode_image(file)
 
 
 def _parse_json_field(name: str, raw: str, expected_type: type) -> Any:
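
Decoding now happens in memory with OpenCV instead of round-tripping through temp files, which also removes all the unlink bookkeeping from the endpoints. The same path in isolation (sample.jpg is an assumption):

    import cv2
    import numpy as np

    with open("sample.jpg", "rb") as f:
        data = f.read()
    # imdecode returns None on undecodable input rather than raising,
    # which is why the handler checks for None.
    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("not a decodable image")
    print(img.shape)  # (height, width, 3), BGR channel order
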
@@ -668,65 +274,20 @@ def _parse_rules(raw: str) -> list:
     return rules
 
 
-@app.options("/ocr")
-async def ocr_options():
-    """Handle CORS preflight requests (required by HuggingFace Spaces)"""
-    return {"message": "OK"}
-
-@app.options("/api/predict")
-async def predict_options():
-    """Handle CORS preflight for HuggingFace Spaces auto-routing"""
-    return {"message": "OK"}
-
 @app.post("/ocr")
-@app.post("/api/predict")  # HuggingFace Spaces may auto-route POST requests here
 async def ocr_page(
     file: UploadFile,
     _: None = Depends(enforce_rate_limit),
 ):
-    """OCR endpoint using DeepSeek-OCR"""
-    img, img_path = await load_img(file)
-    try:
-        # Save PIL image to temporary file for DeepSeek-OCR
-        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
-            img.save(tmp_file, 'JPEG', quality=95)
-            tmp_img_path = tmp_file.name
-
-        try:
-            # Use grounding prompt for better structure extraction
-            result = await run_deepseek_ocr(
-                tmp_img_path,
-                prompt="<image>\n<|grounding|>Convert the document to markdown with preserved layout.",
-                use_grounding=True
-            )
-            return result
-        except Exception as e:
-            # Log the error but don't crash - return a helpful error message
-            error_msg = str(e)
-            print(f"OCR processing error: {error_msg}")
-
-            # Check if it's a model loading issue
-            if "matplotlib" in error_msg or "torchvision" in error_msg or "ImportError" in error_msg:
-                raise HTTPException(
-                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                    detail=f"OCR model dependencies missing: {error_msg}. Please install required packages."
-                )
-            elif "Connection" in error_msg or "timeout" in error_msg.lower():
-                raise HTTPException(
-                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                    detail=f"OCR service temporarily unavailable: {error_msg}"
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                    detail=f"OCR processing failed: {error_msg}"
-                )
-        finally:
-            if os.path.exists(tmp_img_path):
-                os.unlink(tmp_img_path)
-    finally:
-        if os.path.exists(img_path):
-            os.unlink(img_path)
+    img = load_img(file)
+    res = ocr.ocr(img, cls=True)
+    lines = []
+    full = []
+    for quad, (txt, conf) in (res[0] or []):  # each line: four-point quad + (text, conf)
+        # Reduce the quad to an axis-aligned [x1, y1, x2, y2] envelope.
+        xs, ys = [p[0] for p in quad], [p[1] for p in quad]
+        lines.append({"bbox": [min(xs), min(ys), max(xs), max(ys)], "text": txt, "conf": float(conf)})
+        full.append(txt)
+    return {"text": "\n".join(full), "lines": lines}
 
 
 @app.post("/split")
@@ -739,88 +300,71 @@ async def split(
     rules: str = Form("[]"),
     _: None = Depends(enforce_rate_limit),
 ):
-    """Split endpoint - uses DeepSeek-OCR for region extraction"""
-    img, img_path = await load_img(file)
-    try:
-        width, height = img.size
-
-        # Save image for DeepSeek-OCR
-        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
-            img.save(tmp_file, 'JPEG', quality=95)
-            tmp_img_path = tmp_file.name
-
-        try:
-            parent_box = _parse_parent_bbox(parent_bbox, width, height)
-            x1, y1, x2, y2 = parent_box
-
-            # Crop image to parent bbox
-            crop_img = img.crop((int(x1), int(y1), int(x2), int(y2)))
-            crop_path = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg').name
-            crop_img.save(crop_path, 'JPEG', quality=95)
-
-            try:
-                # Use DeepSeek-OCR with grounding prompt for better structured extraction
-                prompt = "<image>\n<|grounding|>Convert the document region to markdown with preserved layout."
-                ocr_result = await run_deepseek_ocr(crop_path, prompt=prompt, use_grounding=True)
-
-                # Parse OCR result to extract lines
-                child_lines = ocr_result.get("lines", [])
-
-                # Adjust bboxes to parent coordinate space
-                for line in child_lines:
-                    bbox = line["bbox"]
-                    line["bbox"] = [
-                        bbox[0] + x1,
-                        bbox[1] + y1,
-                        bbox[2] + x1,
-                        bbox[3] + y1,
-                    ]
-                    line["blockType"] = "text"
-
-                if len(child_lines) > MAX_CHILD_LINES:
-                    child_lines = child_lines[:MAX_CHILD_LINES]
-
-                sanitized_splitter = _sanitize_label("splitter", splitter)
-                sanitized_schema = _sanitize_label("schemaType", schemaType)
-                parsed_settings = _parse_settings(settings)
-                parsed_rules = _parse_rules(rules)
-
-                raw_text = "\n".join([l["text"] for l in child_lines])
-                text_truncated = False
-                if len(raw_text) > 5000:
-                    raw_text = raw_text[:5000]
-                    text_truncated = True
-
-                llm_input = {
-                    "schemaType": sanitized_schema,
-                    "splitter": sanitized_splitter,
-                    "page": {"width": width, "height": height},
-                    "parentBox": parent_box,
-                    "rawText": raw_text,
-                    "ocrLines": child_lines,
-                    "rawTextTruncated": text_truncated,
-                    "ocrLinesTruncated": len(child_lines) >= MAX_CHILD_LINES,
-                    "settings": parsed_settings,
-                    "rules": parsed_rules,
-                }
-
-                try:
-                    llm_result = await call_llm_splitter(llm_input)
-                except ValueError as exc:
-                    raise HTTPException(
-                        status_code=status.HTTP_502_BAD_GATEWAY,
-                        detail=str(exc),
-                    ) from exc
-                return llm_result
-            finally:
-                if os.path.exists(crop_path):
-                    os.unlink(crop_path)
-        finally:
-            if os.path.exists(tmp_img_path):
-                os.unlink(tmp_img_path)
-    finally:
-        if os.path.exists(img_path):
-            os.unlink(img_path)
+    img = load_img(file)
+    height, width = img.shape[:2]
+    parent_box = _parse_parent_bbox(parent_bbox, width, height)
+    x1, y1, x2, y2 = parent_box
+    x1_i, y1_i, x2_i, y2_i = [int(round(v)) for v in parent_box]
+    crop = img[y1_i:y2_i, x1_i:x2_i]
+    if crop.size == 0:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="parent_bbox region is empty",
+        )
+    blocks = pp(crop)
+
+    child_lines = []
+    for b in blocks:
+        bx1, by1, bx2, by2 = b["bbox"]
+        sub = crop[by1:by2, bx1:bx2]
+        det = ocr.ocr(sub, cls=True)
+        for quad, (txt, conf) in (det[0] or []):  # quad: four [x, y] corner points
+            lxs, lys = [p[0] for p in quad], [p[1] for p in quad]
+            child_lines.append(
+                {
+                    "bbox": [min(lxs) + bx1 + x1, min(lys) + by1 + y1, max(lxs) + bx1 + x1, max(lys) + by1 + y1],
+                    "text": txt,
+                    "conf": float(conf),
+                    "blockType": b.get("type", "text"),
+                }
+            )
+            if len(child_lines) >= MAX_CHILD_LINES:
+                break
+        if len(child_lines) >= MAX_CHILD_LINES:
+            break
+
+    sanitized_splitter = _sanitize_label("splitter", splitter)
+    sanitized_schema = _sanitize_label("schemaType", schemaType)
+    parsed_settings = _parse_settings(settings)
+    parsed_rules = _parse_rules(rules)
+
+    raw_text = "\n".join([l["text"] for l in child_lines])
+    text_truncated = False
+    if len(raw_text) > 5000:
+        raw_text = raw_text[:5000]
+        text_truncated = True
+
+    llm_input = {
+        "schemaType": sanitized_schema,
+        "splitter": sanitized_splitter,
+        "page": {"width": width, "height": height},
+        "parentBox": parent_box,
+        "rawText": raw_text,
+        "ocrLines": child_lines,
+        "rawTextTruncated": text_truncated,
+        "ocrLinesTruncated": len(child_lines) >= MAX_CHILD_LINES,
+        "settings": parsed_settings,
+        "rules": parsed_rules,
+    }
+
+    try:
+        llm_result = await call_llm_splitter(llm_input)
+    except ValueError as exc:
+        raise HTTPException(
+            status_code=status.HTTP_502_BAD_GATEWAY,
+            detail=str(exc),
+        ) from exc
+    return llm_result
 
 
 if __name__ == "__main__":
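
Each line box in /split is detected in the frame of the block sub-crop, which sits inside the parent crop, so two offsets bring it back to page coordinates. A worked example with made-up numbers:

    x1, y1 = 100.0, 200.0  # parent_box origin on the page
    bx1, by1 = 20, 30      # layout block origin inside the crop
    lx, ly = 5, 4          # line envelope origin inside the block
    print(lx + bx1 + x1, ly + by1 + y1)  # 125.0 234.0
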