"""
Layout Detection Base Interface

Defines the configuration model, the per-page result container, and the
abstract base class for document layout detection.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import numpy as np
from pydantic import BaseModel, Field

from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion


class LayoutConfig(BaseModel):
    """Configuration for layout detection.

    Each option is declared as a pydantic ``Field`` carrying a default, a
    human-readable description and, where given, numeric bounds (``ge`` /
    ``le``) that pydantic enforces when the model is validated.

    Attributes:
        method: Name of the detection method. This is a plain string; the
            names listed in the field description are not enforced here.
        min_confidence: Minimum confidence for detected regions, in [0, 1].
        detect_tables: Whether to detect table regions.
        detect_figures: Whether to detect figure regions.
        detect_headers: Whether to detect headers/footers.
        detect_titles: Whether to detect titles/headings.
        detect_lists: Whether to detect list structures.
        merge_threshold: IoU threshold for merging overlapping regions,
            in [0, 1].
        use_gpu: Whether to use GPU acceleration.
        gpu_id: GPU device ID (non-negative).
        table_min_rows: Minimum number of rows for a table (at least 1).
        table_min_cols: Minimum number of columns for a table (at least 1).
        title_max_lines: Maximum number of lines for a title (no bound is
            declared on this field).
        heading_font_ratio: Font size ratio vs. body text for headings (no
            bound is declared on this field).
    """

    # --- Detection method -------------------------------------------------
    method: str = Field(
        default="rule_based",
        description="Detection method: rule_based, paddle_structure, layoutlm"
    )

    # --- Confidence filtering ---------------------------------------------
    min_confidence: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum confidence for detected regions"
    )

    # --- Region types to detect -------------------------------------------
    detect_tables: bool = Field(default=True, description="Detect table regions")
    detect_figures: bool = Field(default=True, description="Detect figure regions")
    detect_headers: bool = Field(default=True, description="Detect header/footer")
    detect_titles: bool = Field(default=True, description="Detect title/heading")
    detect_lists: bool = Field(default=True, description="Detect list structures")

    # --- Region merging ---------------------------------------------------
    merge_threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="IoU threshold for merging overlapping regions"
    )

    # --- Hardware ---------------------------------------------------------
    use_gpu: bool = Field(default=True, description="Use GPU acceleration")
    gpu_id: int = Field(default=0, ge=0, description="GPU device ID")

    # --- Table detection --------------------------------------------------
    table_min_rows: int = Field(default=2, ge=1, description="Minimum rows for table")
    table_min_cols: int = Field(default=2, ge=1, description="Minimum columns for table")

    # --- Title / heading detection ----------------------------------------
    title_max_lines: int = Field(default=3, description="Max lines for title")
    heading_font_ratio: float = Field(
        default=1.2,
        description="Font size ratio vs body text for headings"
    )


@dataclass
class LayoutResult:
    """Result of layout detection for a page.

    Attributes:
        page: Page number this result belongs to.
        regions: Detected layout regions; a fresh empty list per instance
            when not supplied.
        image_width: Width of the analysed image (0 when not set;
            presumably pixels -- confirm with the producing detector).
        image_height: Height of the analysed image (0 when not set;
            presumably pixels -- confirm with the producing detector).
        processing_time_ms: Processing time, in milliseconds per the field
            name (0.0 when not set).
        success: Whether detection succeeded; defaults to True.
        error: Optional error message; defaults to None.
    """

    page: int
    regions: List[LayoutRegion] = field(default_factory=list)
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0

    # Status of the detection run
    success: bool = True
    error: Optional[str] = None

    def get_regions_by_type(self, layout_type: LayoutType) -> List[LayoutRegion]:
        """Get regions of a specific type.

        Args:
            layout_type: Type to select; compared with ``==`` against each
                region's ``type`` attribute.

        Returns:
            A new list with the matching regions, in their original order.
        """
        return [region for region in self.regions if region.type == layout_type]

    def get_tables(self) -> List[LayoutRegion]:
        """Get table regions (``LayoutType.TABLE``)."""
        return self.get_regions_by_type(LayoutType.TABLE)

    def get_figures(self) -> List[LayoutRegion]:
        """Get figure regions (``LayoutType.FIGURE``)."""
        return self.get_regions_by_type(LayoutType.FIGURE)

    def get_text_regions(self) -> List[LayoutRegion]:
        """Get text-based regions (text, title, heading, paragraph, list).

        Returns:
            A new list with every region whose type is one of the text-like
            layout types, in their original order.
        """
        text_types = {
            LayoutType.TEXT,
            LayoutType.TITLE,
            LayoutType.HEADING,
            LayoutType.PARAGRAPH,
            LayoutType.LIST,
        }
        return [region for region in self.regions if region.type in text_types]


class LayoutDetector(ABC):
    """
    Abstract base class for layout detectors.

    Subclasses implement ``initialize`` and ``detect``. The base class
    stores the configuration, provides a sequential ``detect_batch`` built
    on ``detect``, and supports use as a context manager, which calls
    ``initialize`` on entry when the detector is not yet initialized.
    """

    def __init__(self, config: Optional[LayoutConfig] = None):
        """
        Initialize layout detector.

        Args:
            config: Layout detection configuration; a default
                ``LayoutConfig()`` is created when omitted.
        """
        self.config = config if config is not None else LayoutConfig()
        # The base class never sets this to True itself; implementations of
        # ``initialize`` are presumably expected to do so.
        self._initialized = False

    @abstractmethod
    def initialize(self) -> None:
        """Initialize the detector (load models, etc.)."""
        pass

    @abstractmethod
    def detect(
        self,
        image: np.ndarray,
        page_number: int = 0,
        ocr_regions: Optional[List[OCRRegion]] = None,
    ) -> LayoutResult:
        """
        Detect layout regions in an image.

        Args:
            image: Image as numpy array (RGB, HWC format)
            page_number: Page number
            ocr_regions: Optional OCR regions for text-aware detection

        Returns:
            LayoutResult with detected regions
        """
        pass

    def detect_batch(
        self,
        images: List[np.ndarray],
        page_numbers: Optional[List[int]] = None,
        ocr_results: Optional[List[List[OCRRegion]]] = None,
    ) -> List[LayoutResult]:
        """
        Detect layout in multiple images.

        Images are processed one after another by calling ``detect``.

        Args:
            images: List of images
            page_numbers: Optional page numbers, one per image; defaults to
                ``0 .. len(images) - 1``
            ocr_results: Optional OCR regions for each page, one entry per
                image; defaults to ``None`` for every image

        Returns:
            List of LayoutResult, one per image, in input order

        Raises:
            ValueError: If ``page_numbers`` or ``ocr_results`` is given but
                its length differs from ``len(images)``.
        """
        image_count = len(images)

        if page_numbers is None:
            page_numbers = list(range(image_count))
        elif len(page_numbers) != image_count:
            # zip() below would silently drop the surplus pages otherwise.
            raise ValueError(
                f"page_numbers has {len(page_numbers)} entries "
                f"but {image_count} images were given"
            )

        if ocr_results is None:
            ocr_results = [None] * image_count
        elif len(ocr_results) != image_count:
            raise ValueError(
                f"ocr_results has {len(ocr_results)} entries "
                f"but {image_count} images were given"
            )

        return [
            self.detect(img, page_num, ocr)
            for img, page_num, ocr in zip(images, page_numbers, ocr_results)
        ]

    @property
    def name(self) -> str:
        """Return detector name (the concrete class name)."""
        return self.__class__.__name__

    @property
    def is_initialized(self) -> bool:
        """Check if detector is initialized."""
        return self._initialized

    def __enter__(self) -> "LayoutDetector":
        """Context manager entry; initializes the detector if needed."""
        if not self._initialized:
            self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit; does nothing and does not suppress exceptions."""
        pass