| """ |
| Layout Detection Model Interface |
| |
| Abstract interface for document layout analysis models. |
| Detects regions like text blocks, tables, figures, headers, etc. |
| """ |
|
|
| from abc import abstractmethod |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from ..chunks.models import BoundingBox, ChunkType |
| from .base import ( |
| BaseModel, |
| BatchableModel, |
| ImageInput, |
| ModelCapability, |
| ModelConfig, |
| ) |
|
|
|
|
class LayoutRegionType(str, Enum):
    """Categories of page regions a layout detector can report.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (``LayoutRegionType.TABLE == "table"``).
    """

    # Running text
    TEXT = "text"
    TITLE = "title"
    HEADING = "heading"
    PARAGRAPH = "paragraph"
    LIST = "list"

    # Structured / non-prose content
    TABLE = "table"
    FIGURE = "figure"
    CHART = "chart"
    FORMULA = "formula"
    CODE = "code"

    # Page furniture and annotations
    HEADER = "header"
    FOOTER = "footer"
    PAGE_NUMBER = "page_number"
    CAPTION = "caption"
    FOOTNOTE = "footnote"

    # Marks and form elements
    LOGO = "logo"
    SIGNATURE = "signature"
    STAMP = "stamp"
    WATERMARK = "watermark"
    FORM_FIELD = "form_field"
    CHECKBOX = "checkbox"

    # Fallback for anything that could not be classified
    UNKNOWN = "unknown"

    def to_chunk_type(self) -> ChunkType:
        """Return the ``ChunkType`` that corresponds to this region type.

        Every member except ``UNKNOWN`` has an explicit entry in the table
        below. ``UNKNOWN`` has none and therefore falls back to
        ``ChunkType.TEXT``.
        """
        chunk_type_for = {
            LayoutRegionType.TEXT: ChunkType.TEXT,
            LayoutRegionType.TITLE: ChunkType.TITLE,
            LayoutRegionType.HEADING: ChunkType.HEADING,
            LayoutRegionType.PARAGRAPH: ChunkType.PARAGRAPH,
            LayoutRegionType.LIST: ChunkType.LIST,
            LayoutRegionType.TABLE: ChunkType.TABLE,
            LayoutRegionType.FIGURE: ChunkType.FIGURE,
            LayoutRegionType.CHART: ChunkType.CHART,
            LayoutRegionType.FORMULA: ChunkType.FORMULA,
            LayoutRegionType.CODE: ChunkType.CODE,
            LayoutRegionType.HEADER: ChunkType.HEADER,
            LayoutRegionType.FOOTER: ChunkType.FOOTER,
            LayoutRegionType.PAGE_NUMBER: ChunkType.PAGE_NUMBER,
            LayoutRegionType.CAPTION: ChunkType.CAPTION,
            LayoutRegionType.FOOTNOTE: ChunkType.FOOTNOTE,
            LayoutRegionType.LOGO: ChunkType.LOGO,
            LayoutRegionType.SIGNATURE: ChunkType.SIGNATURE,
            LayoutRegionType.STAMP: ChunkType.STAMP,
            LayoutRegionType.WATERMARK: ChunkType.WATERMARK,
            LayoutRegionType.FORM_FIELD: ChunkType.FORM_FIELD,
            LayoutRegionType.CHECKBOX: ChunkType.CHECKBOX,
        }
        try:
            return chunk_type_for[self]
        except KeyError:
            # No explicit entry (currently only UNKNOWN): treat as plain text.
            return ChunkType.TEXT
|
|
|
|
@dataclass
class LayoutConfig(ModelConfig):
    """Settings for layout detection models, layered on ``ModelConfig``.

    None of these fields is consumed in this module except
    ``detect_reading_order``; the notes below describe what the names
    suggest and should be confirmed against concrete detectors.
    """

    # Presumably the score below which detections are discarded - TODO confirm.
    min_confidence: float = 0.5
    # Presumably toggles merging of overlapping detections, with
    # overlap_threshold as the cut-off - TODO confirm metric (IoU?).
    merge_overlapping: bool = True
    overlap_threshold: float = 0.5
    # When true, LayoutModel.get_capabilities advertises reading order.
    detect_reading_order: bool = True
    detect_columns: bool = True
    # None presumably means "no restriction on region types" - TODO confirm.
    region_types: Optional[List[LayoutRegionType]] = None

    def __post_init__(self):
        """Run the base-class post-init, then default an empty ``name``."""
        super().__post_init__()
        if self.name:
            return
        self.name = "layout_detector"
|
|
|
|
@dataclass
class LayoutRegion:
    """One region found on a page by a layout detector.

    If ``region_id`` is left empty, a deterministic 12-character id is
    derived from the region type and bounding box, so two regions with the
    same type and box receive the same id.
    """

    region_type: LayoutRegionType
    bbox: BoundingBox
    confidence: float
    region_id: str = ""

    # Position in the reading sequence; -1 presumably means "not assigned".
    reading_order: int = -1

    # Hierarchy links, expressed as region ids.
    parent_id: Optional[str] = None
    child_ids: List[str] = field(default_factory=list)

    # Column placement.
    column_index: int = 0
    num_columns: int = 1

    # Free-form, detector-specific extras.
    attributes: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Fill in ``region_id`` when the caller did not supply one."""
        if self.region_id:
            return

        import hashlib

        # Same type + same box => same seed => same id.
        seed = "{}_{}".format(self.region_type.value, self.bbox.xyxy)
        digest = hashlib.md5(seed.encode()).hexdigest()
        self.region_id = digest[:12]
|
|
|
|
@dataclass
class LayoutResult:
    """Complete layout analysis result for a page."""

    # Detected regions, in whatever order the detector produced them.
    regions: List[LayoutRegion] = field(default_factory=list)
    # Region ids in reading order; empty means "no explicit order".
    reading_order: List[str] = field(default_factory=list)
    num_columns: int = 1
    # Units are not established in this module - TODO confirm (degrees?).
    page_orientation: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    def get_regions_by_type(self, region_type: LayoutRegionType) -> List[LayoutRegion]:
        """Return every region whose ``region_type`` equals *region_type*."""
        return [r for r in self.regions if r.region_type == region_type]

    def get_region_by_id(self, region_id: str) -> Optional[LayoutRegion]:
        """Return the first region with the given id, or ``None`` if absent."""
        return next(
            (r for r in self.regions if r.region_id == region_id),
            None,
        )

    def get_ordered_regions(self) -> List[LayoutRegion]:
        """Return regions in reading order.

        Without an explicit ``reading_order`` the regions are sorted
        top-to-bottom, then left-to-right, by ``(bbox.y_min, bbox.x_min)``.

        With an explicit ``reading_order`` the ids are resolved in sequence.
        Ids that match no region are skipped, and regions whose id is not
        listed are omitted from the result. If several regions share an id,
        the first one in ``regions`` is used.
        """
        if not self.reading_order:
            return sorted(
                self.regions,
                key=lambda r: (r.bbox.y_min, r.bbox.x_min),
            )

        # Index the regions once instead of scanning the list per id.
        # setdefault keeps the first region seen for a duplicated id, which
        # matches what get_region_by_id would return.
        first_by_id: Dict[str, LayoutRegion] = {}
        for region in self.regions:
            first_by_id.setdefault(region.region_id, region)

        return [
            first_by_id[region_id]
            for region_id in self.reading_order
            if region_id in first_by_id
        ]

    def get_tables(self) -> List[LayoutRegion]:
        """Return all table regions."""
        return self.get_regions_by_type(LayoutRegionType.TABLE)

    def get_figures(self) -> List[LayoutRegion]:
        """Return all figure regions."""
        return self.get_regions_by_type(LayoutRegionType.FIGURE)

    def get_text_regions(self) -> List[LayoutRegion]:
        """Return all text-based regions.

        Covers text, title, heading, paragraph, list, caption and footnote
        regions.
        """
        text_types = {
            LayoutRegionType.TEXT,
            LayoutRegionType.TITLE,
            LayoutRegionType.HEADING,
            LayoutRegionType.PARAGRAPH,
            LayoutRegionType.LIST,
            LayoutRegionType.CAPTION,
            LayoutRegionType.FOOTNOTE,
        }
        return [r for r in self.regions if r.region_type in text_types]
|
|
|
|
class LayoutModel(BatchableModel):
    """
    Abstract base class for layout detection models.

    Implementations should detect:
    - Document regions (text, tables, figures, etc.)
    - Reading order
    - Column structure
    - Region hierarchy
    """

    def __init__(self, config: Optional[LayoutConfig] = None):
        """Initialise with *config*, or a ``LayoutConfig`` named "layout"."""
        effective_config = config if config else LayoutConfig(name="layout")
        super().__init__(effective_config)
        # Re-bind so the attribute is typed as LayoutConfig for readers and
        # type checkers; the value itself is unchanged.
        self.config: LayoutConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities this model advertises.

        Layout detection is always included; reading order is added only
        when ``config.detect_reading_order`` is set.
        """
        supported = [ModelCapability.LAYOUT_DETECTION]
        if not self.config.detect_reading_order:
            return supported
        return supported + [ModelCapability.READING_ORDER]

    @abstractmethod
    def detect(
        self,
        image: ImageInput,
        **kwargs
    ) -> LayoutResult:
        """
        Detect layout regions in an image.

        Args:
            image: Input document image
            **kwargs: Additional parameters

        Returns:
            LayoutResult with detected regions
        """

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[LayoutResult]:
        """Run ``detect`` on each input in turn, preserving input order.

        The same ``kwargs`` are forwarded to every call.
        """
        results: List[LayoutResult] = []
        for single_image in inputs:
            results.append(self.detect(single_image, **kwargs))
        return results

    def detect_tables(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[LayoutRegion]:
        """
        Detect only table regions.

        Convenience method: runs the full ``detect`` pass and keeps the
        table regions from its result.
        """
        return self.detect(image, **kwargs).get_tables()

    def detect_figures(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[LayoutRegion]:
        """Detect only figure regions (full ``detect`` pass, then filtered)."""
        return self.detect(image, **kwargs).get_figures()
|
|
|
|
class ReadingOrderModel(BaseModel):
    """
    Abstract base class for reading order determination.

    Reading order may be computed by a model that is separate from layout
    detection, for example when complex layouts need a specialized model.
    """

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the single capability this kind of model provides."""
        capabilities = [ModelCapability.READING_ORDER]
        return capabilities

    @abstractmethod
    def determine_order(
        self,
        regions: List[LayoutRegion],
        image: Optional[ImageInput] = None,
        **kwargs
    ) -> List[str]:
        """
        Determine reading order for a list of regions.

        Args:
            regions: Layout regions to order
            image: Optional image for visual cues
            **kwargs: Additional parameters

        Returns:
            List of region_ids in reading order
        """
|
|
|
|
class HeuristicReadingOrderModel(ReadingOrderModel):
    """
    Simple heuristic-based reading order model.

    Groups regions into columns from their horizontal extents, then reads
    each column top-to-bottom and the columns left-to-right. Suitable for
    simple document layouts.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialise with *config*, or one named "heuristic_reading_order"."""
        effective_config = (
            config if config else ModelConfig(name="heuristic_reading_order")
        )
        super().__init__(effective_config)

    def load(self) -> None:
        """Mark the model as loaded; there are no weights to read."""
        self._is_loaded = True

    def unload(self) -> None:
        """Mark the model as unloaded."""
        self._is_loaded = False

    def determine_order(
        self,
        regions: List[LayoutRegion],
        image: Optional[ImageInput] = None,
        column_threshold: float = 0.3,
        **kwargs
    ) -> List[str]:
        """
        Determine reading order using heuristics.

        Strategy:
        1. Group regions into columns (see ``_detect_columns``)
        2. Within each column, sort top-to-bottom by ``bbox.y_min``
        3. Emit columns left-to-right

        Args:
            regions: Layout regions to order
            image: Not used by this implementation
            column_threshold: Forwarded to ``_detect_columns``.
                NOTE(review): it is not consulted there, so it currently
                has no effect on the result.
            **kwargs: Ignored

        Returns:
            List of region_ids in reading order (empty for empty input)
        """
        if not regions:
            return []

        ordered_ids: List[str] = []
        for column in self._detect_columns(regions, column_threshold):
            top_to_bottom = sorted(column, key=lambda r: r.bbox.y_min)
            for region in top_to_bottom:
                ordered_ids.append(region.region_id)
        return ordered_ids

    def _detect_columns(
        self,
        regions: List[LayoutRegion],
        threshold: float
    ) -> List[List[LayoutRegion]]:
        """Group regions into columns by horizontal overlap.

        Regions are visited in order of increasing ``bbox.x_min``. A region
        joins the current column when its x-range strictly overlaps the
        x-range of the region added just before it; otherwise it starts a
        new column. Only the most recently added region is compared, not
        the whole column.

        NOTE(review): ``threshold`` is accepted but never used. Any positive
        overlap is enough to join a column.
        """
        if not regions:
            return []

        left_to_right = sorted(regions, key=lambda r: r.bbox.x_min)

        # The last inner list is always the column currently being built.
        columns: List[List[LayoutRegion]] = [[left_to_right[0]]]

        for candidate in left_to_right[1:]:
            previous_box = columns[-1][-1].bbox

            # Horizontal span shared by the candidate and the previous region.
            shared_left = max(candidate.bbox.x_min, previous_box.x_min)
            shared_right = min(candidate.bbox.x_max, previous_box.x_max)

            if shared_right > shared_left:
                columns[-1].append(candidate)
            else:
                # No overlap (or the edges merely touch): start a new column.
                columns.append([candidate])

        return columns
|
|