import asyncio
import json
import math
import os
import platform
import re
import secrets
import tempfile
from collections import defaultdict, deque
from pathlib import Path
from time import monotonic
from typing import Any, Deque, DefaultDict, Optional

from fastapi import Depends, FastAPI, Form, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from PIL import Image
# Lazy import DeepSeek-OCR dependencies (only load when needed)
_torch = None
_transformers = None


def _get_torch():
    global _torch
    if _torch is None:
        try:
            import torch

            _torch = torch
        except ImportError:
            raise RuntimeError(
                "torch is not installed. Install with: pip install torch"
            )
    return _torch


def _get_transformers():
    global _transformers
    if _transformers is None:
        try:
            from transformers import AutoModel, AutoTokenizer

            _transformers = (AutoModel, AutoTokenizer)
        except ImportError:
            raise RuntimeError(
                "transformers is not installed. Install with: pip install transformers"
            )
    return _transformers

# Import llm_splitter (works as module or direct import)
try:
    from llm_splitter import call_llm_splitter
except ImportError:
    # Fallback for relative import
    try:
        from .llm_splitter import call_llm_splitter
    except ImportError:
        # If llm_splitter doesn't exist, define a stub
        async def call_llm_splitter(*args, **kwargs):
            raise NotImplementedError("llm_splitter not available")

ALLOWED_CONTENT_TYPES = {
    "image/jpeg",
    "image/png",
    "image/webp",
}
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(5 * 1024 * 1024)))
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "30"))
RATE_LIMIT_WINDOW_SECONDS = float(os.getenv("RATE_LIMIT_WINDOW_SECONDS", "60"))

# Allow the API key to be optional for development (security risk in production!)
SERVICE_API_KEY = os.getenv("SERVICE_API_KEY", "dev-key-change-in-production")
REQUIRE_API_KEY = os.getenv("REQUIRE_API_KEY", "false").lower() == "true"
API_KEY_HEADER_NAME = "X-API-Key"

MAX_CHILD_LINES = 500
MAX_JSON_DEPTH = 4
MAX_JSON_STRING_LENGTH = 512
MAX_JSON_DICT_KEYS = 50
MAX_JSON_LIST_ITEMS = 100

# DeepSeek-OCR model configuration - maximum-quality settings for M4 Mac (Apple Silicon)
MODEL_NAME = "deepseek-ai/DeepSeek-OCR"

# Detect Apple Silicon (M1/M2/M3/M4) - use MPS if available, otherwise CPU
IS_APPLE_SILICON = platform.machine() == "arm64"
USE_GPU = os.getenv("USE_GPU", "true").lower() == "true" and not IS_APPLE_SILICON  # M4 uses MPS, not CUDA
USE_MPS = IS_APPLE_SILICON  # Use Metal Performance Shaders on Apple Silicon

# Maximum-quality settings (larger sizes are slower but more accurate)
BASE_SIZE = int(os.getenv("DEEPSEEK_BASE_SIZE", "1280"))  # 1280 = maximum quality
IMAGE_SIZE = int(os.getenv("DEEPSEEK_IMAGE_SIZE", "1280"))  # 1280 = maximum quality
CROP_MODE = os.getenv("DEEPSEEK_CROP_MODE", "true").lower() == "true"  # True for best accuracy
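
# Example: overriding the knobs above via environment variables before launch.
# These values are illustrative, not defaults enforced anywhere in this file:
#
#   export SERVICE_API_KEY="a-long-random-secret"
#   export REQUIRE_API_KEY="true"          # enforce the X-API-Key header
#   export MAX_UPLOAD_BYTES="10485760"     # raise the upload cap to 10 MiB
#   export RATE_LIMIT_REQUESTS="60"        # 60 requests per window...
#   export RATE_LIMIT_WINDOW_SECONDS="60"  # ...per 60-second window
#   export DEEPSEEK_BASE_SIZE="1024"       # trade accuracy for speed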

app = FastAPI()

# Add CORS middleware to allow frontend requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize DeepSeek-OCR model
_ocr_model = None
_ocr_tokenizer = None
_model_lock = asyncio.Lock()


def _patch_deepseek_model_for_m4():
    """
    Patch the cached DeepSeek-OCR model code to fix the LlamaFlashAttention2
    import error on M4 Macs. This is needed because transformers 4.57.1 no
    longer exposes LlamaFlashAttention2, but DeepSeek-OCR's model code tries
    to import it.
    """
    cache_dir = Path.home() / ".cache" / "huggingface"
    model_files = list(cache_dir.glob("**/modeling_deepseekv2.py"))
    if not model_files:
        return  # Model not downloaded yet; will patch on first load
    model_file = model_files[0]

    # Check whether the file is already patched
    try:
        with open(model_file, "r") as f:
            content = f.read()
    except OSError:
        return  # Cannot read the cached file; nothing to patch
    if "LlamaFlashAttention2 = LlamaAttention" in content:
        return  # Already patched

    # Original import pattern
    original_import = """from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2
)"""
    # Patched version with fallback
    patched_import = """from transformers.models.llama.modeling_llama import (
    LlamaAttention,
)
# Patch for M4 Mac: LlamaFlashAttention2 not available in transformers 4.57.1
# Use LlamaAttention as fallback when flash attention is unavailable
try:
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
except ImportError:
    # Fallback: use LlamaAttention when flash attention is not available
    LlamaFlashAttention2 = LlamaAttention"""

    try:
        if original_import in content:
            # Create a backup before modifying the cached file
            backup_file = model_file.with_suffix(".py.backup")
            try:
                with open(backup_file, "w") as f:
                    f.write(content)
            except OSError:
                pass
            # Apply the patch
            content = content.replace(original_import, patched_import)
            with open(model_file, "w") as f:
                f.write(content)
            print("✅ Patched DeepSeek model for M4 Mac compatibility")
    except Exception as e:
        print(f"⚠️ Could not patch model file: {e}")


async def get_ocr_model():
    """Lazy-load the DeepSeek-OCR model, with M4 Mac compatibility patching."""
    global _ocr_model, _ocr_tokenizer
    if _ocr_model is None or _ocr_tokenizer is None:
        async with _model_lock:
            if _ocr_model is None or _ocr_tokenizer is None:
                # Patch DeepSeek model code for M4 Mac compatibility BEFORE loading
                _patch_deepseek_model_for_m4()

                # Lazily import heavy dependencies
                AutoModel, AutoTokenizer = _get_transformers()
                torch = _get_torch()

                print(f"Loading DeepSeek-OCR model (maximum quality): {MODEL_NAME}")
                print(f" - Base size: {BASE_SIZE}")
                print(f" - Image size: {IMAGE_SIZE}")
                print(f" - Crop mode: {CROP_MODE}")

                _ocr_tokenizer = AutoTokenizer.from_pretrained(
                    MODEL_NAME, trust_remote_code=True
                )

                # Load the model with Apple Silicon (M4) friendly settings.
                # On M4 Macs use SDPA, not flash_attention_2: flash attention
                # does not work on Apple Silicon.
                load_kwargs = {
                    "trust_remote_code": True,
                    "use_safetensors": False,  # Avoid safetensors issues on M4
                }
                # Force SDPA attention for Apple Silicon compatibility; this
                # avoids LlamaFlashAttention2 import errors on M4 Macs. On
                # other platforms, let the model choose its attention backend.
                if IS_APPLE_SILICON:
                    load_kwargs["_attn_implementation"] = "sdpa"
                    print(" - Using SDPA attention (Apple Silicon/M4 optimized)")

                try:
                    _ocr_model = AutoModel.from_pretrained(MODEL_NAME, **load_kwargs)
                except Exception as e:
                    error_msg = str(e)
                    print(f"⚠️ Model load error: {error_msg}")
                    # If loading failed on attention selection, retry with a minimal config
                    if "LlamaFlashAttention2" in error_msg or "flash" in error_msg.lower():
                        print(" - Retrying with explicit SDPA attention...")
                        load_kwargs_minimal = {
                            "trust_remote_code": True,
                            "use_safetensors": False,
                            "_attn_implementation": "sdpa",  # Force SDPA
                        }
                        _ocr_model = AutoModel.from_pretrained(MODEL_NAME, **load_kwargs_minimal)
                    else:
                        raise

                _ocr_model = _ocr_model.eval()

                # Handle device placement for M4 Macs (Apple Silicon)
                if USE_MPS and torch.backends.mps.is_available():
                    _ocr_model = _ocr_model.to("mps")
                    print(" - DeepSeek-OCR loaded on Apple Silicon GPU (MPS/M4)")
                elif USE_GPU and torch.cuda.is_available():
                    _ocr_model = _ocr_model.cuda().to(torch.bfloat16)
                    print(" - DeepSeek-OCR loaded on NVIDIA GPU")
                else:
                    print(" - DeepSeek-OCR loaded on CPU")
    return _ocr_model, _ocr_tokenizer
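
# A minimal warm-up sketch (assumption: you want the model resident before the
# first request rather than paying the load cost lazily). FastAPI's startup
# hook is one place to trigger it:
#
#   @app.on_event("startup")
#   async def _warm_up_model():
#       await get_ocr_model()  # runs patching, download, and device placement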


async def run_deepseek_ocr(
    image_path: str,
    prompt: str = "<image>\n<|grounding|>Convert the document to markdown with preserved layout.",
    use_grounding: bool = True,
) -> dict:
    """
    Run DeepSeek-OCR on an image file with grounding support.

    Grounding prompts improve structure extraction and layout preservation,
    following DeepSeek-OCR best practices.
    """
    model, tokenizer = await get_ocr_model()
    output_path = tempfile.mkdtemp()
    try:
        # Maximum-quality inference settings
        result = model.infer(
            tokenizer,
            prompt=prompt,
            image_file=image_path,
            output_path=output_path,
            base_size=BASE_SIZE,    # 1280 = maximum quality
            image_size=IMAGE_SIZE,  # 1280 = maximum quality
            crop_mode=CROP_MODE,    # True = best accuracy for complex documents
            save_results=False,
            test_compress=False,    # False = maximum quality, no compression
        )
        # DeepSeek-OCR returns structured markdown output
        ocr_text = result if isinstance(result, str) else str(result)
        # Extract structured, layout-aware lines from the markdown
        lines = _parse_deepseek_output(ocr_text)
        return {
            "text": ocr_text,
            "lines": lines,
        }
    except Exception as e:
        print(f"DeepSeek-OCR error: {e}")
        import traceback

        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"OCR processing failed: {str(e)}",
        )
    finally:
        # Clean up the temporary output directory
        import shutil

        try:
            if os.path.exists(output_path):
                shutil.rmtree(output_path)
        except OSError:
            pass


def _parse_deepseek_output(ocr_text: str) -> list:
    """
    Extract structured lines from DeepSeek-OCR markdown output.

    Preserves layout and handles tables, lists, headers, and plain text.
    Coordinates are estimates: the markdown carries no pixel positions, so
    bboxes are synthesized from line order and text length.
    """
    lines = []
    text_lines = ocr_text.split("\n")
    y_offset = 0
    line_height = 24  # Estimated line height in pixels

    for line in text_lines:
        stripped = line.strip()
        if not stripped:
            # Empty lines still take vertical space
            y_offset += line_height // 2
            continue

        # Markdown tables (|-separated): split each row into cells
        if "|" in stripped and stripped.count("|") >= 2:
            cells = [cell.strip() for cell in stripped.split("|") if cell.strip()]
            for cell_idx, cell in enumerate(cells):
                lines.append({
                    "bbox": [
                        cell_idx * 200,  # Approximate x position
                        y_offset,
                        (cell_idx + 1) * 200,
                        y_offset + line_height,
                    ],
                    "text": cell,
                    "conf": 0.95,
                })
            y_offset += line_height
        # Markdown lists (-, *, +, or "1."-style numbering)
        elif stripped.startswith(("-", "*", "+")) or re.match(r"\d+\.\s", stripped):
            # List item: remove the list marker, keep the text
            text = re.sub(r"^([-*+]|\d+\.)\s*", "", stripped)
            if text:
                lines.append({
                    "bbox": [40, y_offset, 1000, y_offset + line_height],
                    "text": text,
                    "conf": 0.95,
                })
            y_offset += line_height
        # Markdown headers (# ## ###)
        elif stripped.startswith("#"):
            header_level = len(stripped) - len(stripped.lstrip("#"))
            text = stripped.lstrip("#").strip()
            # Headers are typically rendered larger
            header_height = line_height + (header_level * 4)
            if text:
                lines.append({
                    "bbox": [0, y_offset, 1000, y_offset + header_height],
                    "text": text,
                    "conf": 0.95,
                })
            y_offset += header_height
        # Regular text line
        else:
            # Estimate width from text length (~8 px per character)
            estimated_width = min(len(stripped) * 8, 1000)
            lines.append({
                "bbox": [0, y_offset, estimated_width, y_offset + line_height],
                "text": stripped,
                "conf": 0.95,
            })
            y_offset += line_height

    return lines
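
# Illustrative input/output for _parse_deepseek_output (geometry follows the
# synthesized-bbox rules above; the numbers are estimates, not real pixels):
#
#   _parse_deepseek_output("# Invoice\n- item one")
#   # -> [
#   #   {"bbox": [0, 0, 1000, 28], "text": "Invoice", "conf": 0.95},    # level-1 header
#   #   {"bbox": [40, 28, 1000, 52], "text": "item one", "conf": 0.95}, # list item
#   # ]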

api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)

_rate_limit_lock = asyncio.Lock()
_request_log: DefaultDict[str, Deque[float]] = defaultdict(deque)


def ensure_upload_is_safe(file: UploadFile) -> None:
    # Check the content type from the header
    content_type = (file.content_type or "").lower()
    # Also check the file extension as a fallback (browsers sometimes send
    # application/octet-stream)
    filename = (file.filename or "").lower()
    extension = filename.split(".")[-1] if "." in filename else ""
    allowed_extensions = {"jpg", "jpeg", "png", "webp"}

    # Allow the upload if either the content type or the extension matches
    content_type_valid = content_type in ALLOWED_CONTENT_TYPES
    extension_valid = extension in allowed_extensions
    if not content_type_valid and not extension_valid:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=(
                f"Unsupported file type. Content-Type: {content_type}, "
                f"Extension: {extension}. Allowed: {', '.join(ALLOWED_CONTENT_TYPES)}"
            ),
        )

    file.file.seek(0, os.SEEK_END)
    size = file.file.tell()
    file.file.seek(0)
    if size > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail="Uploaded file exceeds size limit",
        )


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> str:
    # Skip API key verification in development mode
    if not REQUIRE_API_KEY:
        return api_key or SERVICE_API_KEY
    # Enforce the API key in production
    if not api_key or not secrets.compare_digest(api_key, SERVICE_API_KEY):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )
    return api_key


async def enforce_rate_limit(
    request: Request, api_key: str = Depends(verify_api_key)
) -> None:
    if RATE_LIMIT_REQUESTS <= 0:
        return
    identifier = api_key or (request.client.host if request.client else "anonymous")
    now = monotonic()
    async with _rate_limit_lock:
        window = _request_log[identifier]
        # Drop timestamps that have aged out of the sliding window
        while window and now - window[0] > RATE_LIMIT_WINDOW_SECONDS:
            window.popleft()
        if len(window) >= RATE_LIMIT_REQUESTS:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded",
            )
        window.append(now)
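
# Worked example of the sliding window above, with the defaults
# (30 requests / 60 s): a client that sends 30 requests at t=0 is refused at
# t=30 (all 30 timestamps are still inside the 60 s window) but allowed again
# at t=61, once the t=0 timestamps have been popped from the deque.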


def _decode_image(file: UploadFile) -> tuple[Image.Image, str]:
    """Decode an uploaded image to a PIL Image; return it with its temp-file path."""
    data = file.file.read()
    if not data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file is empty",
        )
    # Save to a temp file for DeepSeek-OCR, which reads from disk
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
        tmp_file.write(data)
        tmp_path = tmp_file.name
    try:
        img = Image.open(tmp_path).convert("RGB")
        return img, tmp_path
    except Exception as e:
        os.unlink(tmp_path)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unable to decode image: {str(e)}",
        )


async def load_img(file: UploadFile):
    ensure_upload_is_safe(file)
    file.file.seek(0)
    img, img_path = _decode_image(file)
    return img, img_path


def _parse_json_field(name: str, raw: str, expected_type: type) -> Any:
    try:
        value = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid {name} payload",
        ) from exc
    if not isinstance(value, expected_type):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{name} must be a {expected_type.__name__}",
        )
    return value


def _validate_safe_json(value: Any, name: str, depth: int = 0) -> None:
    if depth > MAX_JSON_DEPTH:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{name} is too deeply nested",
        )
    if isinstance(value, dict):
        if len(value) > MAX_JSON_DICT_KEYS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"{name} has too many keys",
            )
        for key, item in value.items():
            if not isinstance(key, str) or len(key) > 64:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"{name} contains an invalid key",
                )
            _validate_safe_json(item, f"{name}.{key}", depth + 1)
        return
    if isinstance(value, list):
        if len(value) > MAX_JSON_LIST_ITEMS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"{name} has too many entries",
            )
        for idx, item in enumerate(value):
            _validate_safe_json(item, f"{name}[{idx}]", depth + 1)
        return
    if isinstance(value, str):
        if len(value) > MAX_JSON_STRING_LENGTH:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"{name} contains an oversized string",
            )
        # Reject control characters except tab, newline, and carriage return
        if any(ord(ch) < 32 and ord(ch) not in (9, 10, 13) for ch in value):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"{name} contains control characters",
            )
        return
    if isinstance(value, bool) or value is None:
        return
    if isinstance(value, (int, float)):
        if isinstance(value, float) and not math.isfinite(value):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"{name} must contain finite numbers",
            )
        return
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"{name} contains an unsupported value type",
    )
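
# Hypothetical payloads showing what the validator accepts and rejects:
#
#   _validate_safe_json({"note": "ok"}, "settings")          # passes
#   _validate_safe_json({"note": "\x07 bell"}, "settings")   # 400: control characters
#   _validate_safe_json({"n": float("inf")}, "settings")     # 400: non-finite number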


def _sanitize_label(name: str, value: str) -> str:
    if not isinstance(value, str):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{name} must be a string",
        )
    trimmed = value.strip()
    if not trimmed:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{name} cannot be empty",
        )
    if len(trimmed) > 128:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{name} is too long",
        )
    if any(ord(ch) < 32 for ch in trimmed):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{name} contains invalid characters",
        )
    return trimmed


def _parse_parent_bbox(raw: str, width: int, height: int) -> list[float]:
    values = _parse_json_field("parent_bbox", raw, list)
    if len(values) != 4:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="parent_bbox must have four values",
        )
    coords: list[float] = []
    for value in values:
        try:
            coord = float(value)
        except (TypeError, ValueError) as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="parent_bbox must contain numeric values",
            ) from exc
        if not math.isfinite(coord):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="parent_bbox must contain finite coordinates",
            )
        coords.append(coord)
    x1, y1, x2, y2 = coords
    if x2 <= x1 or y2 <= y1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="parent_bbox coordinates are invalid",
        )
    if x1 < 0 or y1 < 0 or x2 > width or y2 > height:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="parent_bbox is outside the image bounds",
        )
    return coords
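
# Example: for a 1000x800 image, '[100, 50, 400, 300]' parses to
# [100.0, 50.0, 400.0, 300.0]; '[400, 300, 100, 50]' is rejected (x2 <= x1),
# and '[0, 0, 1200, 300]' is rejected as outside the image bounds.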


def _parse_settings(raw: str) -> dict:
    settings = _parse_json_field("settings", raw, dict)
    if len(settings) > 50:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="settings payload is too large",
        )
    _validate_safe_json(settings, "settings")
    return settings


def _parse_rules(raw: str) -> list:
    rules = _parse_json_field("rules", raw, list)
    if len(rules) > 100:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="rules payload is too large",
        )
    for idx, rule in enumerate(rules):
        if not isinstance(rule, dict):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="rules entries must be objects",
            )
        _validate_safe_json(rule, f"rules[{idx}]")
    return rules


@app.post("/ocr")  # assumed route path; the decorator is missing in the source
async def ocr_page(
    file: UploadFile,
    _: None = Depends(enforce_rate_limit),
):
    """OCR endpoint using DeepSeek-OCR"""
    img, img_path = await load_img(file)
    try:
        # Save the PIL image to a temporary file for DeepSeek-OCR
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            img.save(tmp_file, "JPEG", quality=95)
            tmp_img_path = tmp_file.name
        try:
            # Use a grounding prompt for better structure extraction
            result = await run_deepseek_ocr(
                tmp_img_path,
                prompt="<image>\n<|grounding|>Convert the document to markdown with preserved layout.",
                use_grounding=True,
            )
            return result
        except HTTPException:
            raise
        except Exception as e:
            # Log the error but don't crash; return a helpful error message
            error_msg = str(e)
            print(f"OCR processing error: {error_msg}")
            # Distinguish missing dependencies from transient failures
            if "matplotlib" in error_msg or "torchvision" in error_msg or "ImportError" in error_msg:
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail=f"OCR model dependencies missing: {error_msg}. Please install required packages.",
                )
            elif "Connection" in error_msg or "timeout" in error_msg.lower():
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail=f"OCR service temporarily unavailable: {error_msg}",
                )
            else:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"OCR processing failed: {error_msg}",
                )
        finally:
            if os.path.exists(tmp_img_path):
                os.unlink(tmp_img_path)
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)


@app.post("/split")  # assumed route path; the decorator is missing in the source
async def split(
    file: UploadFile,
    parent_bbox: str = Form(...),
    splitter: str = Form(...),
    schemaType: str = Form(...),
    settings: str = Form("{}"),
    rules: str = Form("[]"),
    _: None = Depends(enforce_rate_limit),
):
    """Split endpoint: uses DeepSeek-OCR for region extraction"""
    img, img_path = await load_img(file)
    try:
        width, height = img.size
        # Save the image for DeepSeek-OCR
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            img.save(tmp_file, "JPEG", quality=95)
            tmp_img_path = tmp_file.name
        try:
            parent_box = _parse_parent_bbox(parent_bbox, width, height)
            x1, y1, x2, y2 = parent_box

            # Crop the image to the parent bbox
            crop_img = img.crop((int(x1), int(y1), int(x2), int(y2)))
            with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as crop_file:
                crop_path = crop_file.name
            crop_img.save(crop_path, "JPEG", quality=95)
            try:
                # Use DeepSeek-OCR with a grounding prompt for structured extraction
                prompt = "<image>\n<|grounding|>Convert the document region to markdown with preserved layout."
                ocr_result = await run_deepseek_ocr(crop_path, prompt=prompt, use_grounding=True)

                # Extract lines from the OCR result
                child_lines = ocr_result.get("lines", [])
                # Shift bboxes from crop-local into parent coordinate space
                for line in child_lines:
                    bbox = line["bbox"]
                    line["bbox"] = [
                        bbox[0] + x1,
                        bbox[1] + y1,
                        bbox[2] + x1,
                        bbox[3] + y1,
                    ]
                    line["blockType"] = "text"
                if len(child_lines) > MAX_CHILD_LINES:
                    child_lines = child_lines[:MAX_CHILD_LINES]

                sanitized_splitter = _sanitize_label("splitter", splitter)
                sanitized_schema = _sanitize_label("schemaType", schemaType)
                parsed_settings = _parse_settings(settings)
                parsed_rules = _parse_rules(rules)

                raw_text = "\n".join(line["text"] for line in child_lines)
                text_truncated = False
                if len(raw_text) > 5000:
                    raw_text = raw_text[:5000]
                    text_truncated = True

                llm_input = {
                    "schemaType": sanitized_schema,
                    "splitter": sanitized_splitter,
                    "page": {"width": width, "height": height},
                    "parentBox": parent_box,
                    "rawText": raw_text,
                    "ocrLines": child_lines,
                    "rawTextTruncated": text_truncated,
                    "ocrLinesTruncated": len(child_lines) >= MAX_CHILD_LINES,
                    "settings": parsed_settings,
                    "rules": parsed_rules,
                }
                try:
                    llm_result = await call_llm_splitter(llm_input)
                except ValueError as exc:
                    raise HTTPException(
                        status_code=status.HTTP_502_BAD_GATEWAY,
                        detail=str(exc),
                    ) from exc
                return llm_result
            finally:
                if os.path.exists(crop_path):
                    os.unlink(crop_path)
        finally:
            if os.path.exists(tmp_img_path):
                os.unlink(tmp_img_path)
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)
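
# Example requests against a local instance (the paths follow the assumed
# route decorators above; the form values are illustrative). The API key
# header is only enforced when REQUIRE_API_KEY=true:
#
#   curl -X POST http://localhost:8080/ocr \
#     -H "X-API-Key: dev-key-change-in-production" \
#     -F "file=@page.jpg"
#
#   curl -X POST http://localhost:8080/split \
#     -H "X-API-Key: dev-key-change-in-production" \
#     -F "file=@page.jpg" \
#     -F 'parent_bbox=[0, 0, 800, 600]' \
#     -F "splitter=receipt" \
#     -F "schemaType=lineItems" \
#     -F 'settings={}' \
#     -F 'rules=[]'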