| """Core radiology report structuring functionality using LangExtract. | |
| This module provides the RadiologyReportStructurer class that processes raw | |
| radiology reports into structured segments categorized as prefix, body, or suffix sections with clinical significance annotations (normal, minor, significant). | |
| The structuring uses LangExtract with example-guided prompting to extract segments with character intervals that enable interactive hover-to-highlight functionality in the web frontend. | |
| Backend-Frontend Integration: | |
| - Backend generates segments with character intervals (startPos/endPos) | |
| - Frontend creates interactive spans that highlight corresponding input text on hover | |
| - Significance levels drive CSS styling for visual differentiation | |
| - Segment types organize content into structured sections (EXAMINATION, FINDINGS, IMPRESSION) | |
| Example usage: | |
| structurer = RadiologyReportStructurer( | |
| api_key="your_api_key", | |
| model_id="gemini-2.5-flash" | |
| ) | |
| result = structurer.predict("FINDINGS: Normal chest CT...") | |
| """ | |

import collections
import dataclasses
import itertools
from enum import Enum
from functools import wraps
from typing import Any, TypedDict

import langextract as lx
import langextract.data

import prompt_instruction
import prompt_lib
import report_examples


class FrontendIntervalDict(TypedDict):
    """Character interval for frontend with startPos and endPos."""

    startPos: int
    endPos: int


class SegmentDict(TypedDict):
    """Segment dictionary for JSON response."""

    type: str
    label: str | None
    content: str
    intervals: list[FrontendIntervalDict]
    significance: str | None


class SerializedExtractionDict(TypedDict):
    """Serialized extraction for JSON response."""

    extraction_text: str | None
    extraction_class: str | None
    attributes: dict[str, str] | None
    char_interval: dict[str, int | None] | None
    alignment_status: str | None


class ResponseDict(TypedDict):
    """Complete response dictionary structure."""

    segments: list[SegmentDict]
    annotated_document_json: dict[str, Any]
    text: str
    raw_prompt: str
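

# Illustrative sample (hypothetical values, not model output): one entry of
# ResponseDict["segments"] as the frontend consumes it. The interval offsets
# index into the raw report text and drive hover-to-highlight spans.
_EXAMPLE_SEGMENT: SegmentDict = {
    "type": "body",
    "label": "lungs",
    "content": "Lungs are clear bilaterally.",
    "intervals": [{"startPos": 10, "endPos": 38}],
    "significance": "normal",
}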


FINDINGS_HEADER = "FINDINGS:"
IMPRESSION_HEADER = "IMPRESSION:"
EXAMINATION_HEADER = "EXAMINATION:"
SECTION_ATTRIBUTE_KEY = "section"
START_POSITION = "startPos"
END_POSITION = "endPos"
EXAM_PREFIXES = ("EXAMINATION:", "EXAM:", "STUDY:")
EXAMINATION_LABEL = "examination"
PREFIX_LABEL = "prefix"
SIGNIFICANCE_NORMAL = "normal"
SIGNIFICANCE_MINOR = "minor"
SIGNIFICANCE_SIGNIFICANT = "significant"
SIGNIFICANCE_NOT_APPLICABLE = "not_applicable"


def _initialize_langextract_patches():
    """Initializes LangExtract patches for proper alignment behavior.

    Applies a patch to LangExtract's Resolver.align method that forces
    accept_match_lesser=False and sets fuzzy_alignment_threshold to 0.50.
    Call this before using LangExtract functionality.

    Note: This is a temporary workaround until LangExtract exposes
    accept_match_lesser and fuzzy_alignment_threshold parameters via its
    public API.
    """
    # Store the original method so the patch can delegate to it.
    original_align = lx.resolver.Resolver.align

    @wraps(original_align)
    def _align_patched(self, *args, **kwargs):
        # Set the default only if the caller did not provide it explicitly.
        kwargs.setdefault("accept_match_lesser", False)
        # Require at least 50% similarity for fuzzy alignment.
        kwargs.setdefault("fuzzy_alignment_threshold", 0.50)
        return original_align(self, *args, **kwargs)

    # Apply the patch.
    lx.resolver.Resolver.align = _align_patched
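

# A minimal sketch of the patch's effect (hypothetical call sequence;
# Resolver construction details depend on the installed LangExtract version):
#
#   _initialize_langextract_patches()
#   # From here on, any Resolver.align call that omits the two keyword
#   # arguments gets accept_match_lesser=False and
#   # fuzzy_alignment_threshold=0.50 by default; explicit arguments still
#   # win because setdefault() never overrides caller-supplied values.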


class ReportSectionType(Enum):
    """Enum of radiology report sections with their extraction class names."""

    PREFIX = "findings_prefix"
    BODY = "findings_body"
    SUFFIX = "findings_suffix"

    @property
    def display_name(self) -> str:
        """Returns the lowercase section type name for display purposes."""
        return self.name.lower()


@dataclasses.dataclass
class Segment:
    """Represents a single merged segment of text in the structured report.

    Attributes:
        type: The section type (prefix, body, or suffix).
        label: Optional section label for organization.
        content: The text content of this segment.
        intervals: List of character position intervals.
        significance: Optional clinical significance indicator.
    """

    type: ReportSectionType
    label: str | None
    content: str
    intervals: list[FrontendIntervalDict]
    significance: str | None = None

    def to_dict(self) -> SegmentDict:
        """Converts the segment to a dictionary representation.

        Returns:
            A dictionary containing all segment data with type as display
            name.
        """
        return SegmentDict(
            type=self.type.display_name,
            label=self.label,
            content=self.content,
            intervals=self.intervals,
            significance=self.significance,
        )
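

# Example (illustrative values): round-tripping a Segment to its JSON-ready
# dict. The enum type is flattened to its display name:
#
#   Segment(
#       type=ReportSectionType.BODY,
#       label="heart",
#       content="Heart size is normal.",
#       intervals=[{"startPos": 0, "endPos": 21}],
#       significance="normal",
#   ).to_dict()
#   -> {"type": "body", "label": "heart", "content": "Heart size is normal.",
#       "intervals": [{"startPos": 0, "endPos": 21}], "significance": "normal"}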


class RadiologyReportStructurer:
    """Structures radiology reports using LangExtract and large language models.

    This class processes raw radiology report text and converts it into
    structured segments categorized as prefix, body, or suffix sections with
    appropriate labeling and clinical significance annotations.
    """

    api_key: str | None
    model_id: str
    temperature: float
    examples: list[langextract.data.ExampleData]
    _patches_initialized: bool

    def __init__(
        self,
        api_key: str | None = None,
        model_id: str = "gemini-2.5-flash",
        temperature: float = 0.0,
    ):
        """Initializes the RadiologyReportStructurer.

        Args:
            api_key: API key for the language model service.
            model_id: Identifier for the specific model to use.
            temperature: Sampling temperature for model generation.
        """
        self.api_key = api_key
        self.model_id = model_id
        self.temperature = temperature
        self.examples = report_examples.get_examples_for_model()
        self._patches_initialized = False

    def _ensure_patches_initialized(self):
        """Ensures LangExtract patches are applied before first use."""
        if not self._patches_initialized:
            _initialize_langextract_patches()
            self._patches_initialized = True

    def _generate_formatted_prompt_with_examples(
        self, input_text: str | None = None
    ) -> str:
        """Generates a comprehensive, markdown-formatted prompt with examples.

        Args:
            input_text: Optional input text to include in the prompt display.

        Returns:
            A markdown-formatted string containing the full prompt and
            examples.
        """
        return prompt_lib.generate_markdown_prompt(self.examples, input_text)

    def predict(
        self, report_text: str, max_char_buffer: int = 2000
    ) -> ResponseDict:
        """Processes a radiology report text into structured format.

        Takes raw radiology report text and uses LangExtract with
        example-guided prompting to extract structured segments with
        character intervals and clinical significance annotations.

        Args:
            report_text: Raw radiology report text to be processed.
            max_char_buffer: Maximum character buffer size for processing.

        Returns:
            A dictionary containing:
                - segments: List of structured report segments
                - annotated_document_json: Raw extraction results
                - text: Formatted text representation
                - raw_prompt: Markdown-formatted prompt including examples

        Raises:
            ValueError: If report_text is empty or whitespace-only.
        """
        if not report_text.strip():
            raise ValueError("Report text cannot be empty")
        try:
            result = self._perform_langextract(report_text, max_char_buffer)
            return self._build_response(result, report_text)
        except (ValueError, TypeError, AttributeError) as e:
            return ResponseDict(
                text=f"Error processing report: {str(e)}",
                segments=[],
                annotated_document_json={},
                raw_prompt="",
            )

    def _perform_langextract(
        self, report_text: str, max_char_buffer: int
    ) -> langextract.data.AnnotatedDocument:
        """Performs LangExtract processing on the input text.

        Args:
            report_text: Raw radiology report text to be processed.
            max_char_buffer: Maximum character buffer size for processing.

        Returns:
            LangExtract result object containing extractions.

        Raises:
            ValueError: If LangExtract processing fails.
            TypeError: If invalid parameters are provided.
        """
        self._ensure_patches_initialized()
        return lx.extract(
            text_or_documents=report_text,
            prompt_description=prompt_instruction.PROMPT_INSTRUCTION.split(
                "# Few-Shot Examples"
            )[0],
            examples=self.examples,
            model_id=self.model_id,
            api_key=self.api_key,
            max_char_buffer=max_char_buffer,
            temperature=self.temperature,
            # accept_match_lesser and fuzzy_alignment_threshold are handled
            # by the Resolver.align monkey-patch applied lazily in
            # _ensure_patches_initialized above.
        )

    def _build_response(
        self, result: langextract.data.AnnotatedDocument, report_text: str
    ) -> ResponseDict:
        """Builds the final response dictionary from LangExtract results.

        Args:
            result: LangExtract result object containing extractions.
            report_text: Original input text for prompt generation.

        Returns:
            Dictionary containing structured segments and metadata.
        """
        segments = self._build_segments_from_langextract_result(result)
        organized_segments = self._organize_segments_by_label(segments)
        response: ResponseDict = {
            "segments": [segment.to_dict() for segment in organized_segments],
            "annotated_document_json": self._serialize_extraction_results(result),
            "text": self._format_segments_to_text(organized_segments),
            "raw_prompt": self._generate_formatted_prompt_with_examples(report_text),
        }
        return response

    def _serialize_extraction_results(
        self, result: langextract.data.AnnotatedDocument
    ) -> dict[str, Any]:
        """Serializes LangExtract results for JSON response.

        Args:
            result: LangExtract result object containing extractions.

        Returns:
            Dictionary containing serialized extraction data or error
            information.
        """
        try:
            if not hasattr(result, "extractions"):
                return {"error": "No extractions found in result"}
            return {
                "extractions": [
                    self._serialize_single_extraction(extraction)
                    for extraction in result.extractions
                ]
            }
        except (AttributeError, TypeError, KeyError) as e:
            return {
                "error": "Failed to serialize extraction result",
                "error_message": str(e),
                "fallback_string": str(result),
            }

    def _serialize_single_extraction(
        self, extraction: langextract.data.Extraction
    ) -> SerializedExtractionDict:
        """Serializes a single extraction to dictionary format."""
        return {
            "extraction_text": extraction.extraction_text,
            "extraction_class": extraction.extraction_class,
            "attributes": extraction.attributes,
            "char_interval": self._extract_char_interval(extraction),
            "alignment_status": self._get_alignment_status_string(extraction),
        }

    def _get_alignment_status_string(
        self, extraction: langextract.data.Extraction
    ) -> str | None:
        """Extracts alignment status from an extraction as a string."""
        status = getattr(extraction, "alignment_status", None)
        return str(status) if status is not None else None

    def _build_segments_from_langextract_result(
        self, result: langextract.data.AnnotatedDocument
    ) -> list[Segment]:
        """Builds segments from LangExtract results, one segment per interval.

        Creates exactly one segment per character interval to enable precise
        frontend hover-to-highlight functionality. Processes only
        langextract.data.Extraction objects for consistent typing.

        Args:
            result: LangExtract result object containing extractions.

        Returns:
            List of Segment objects optimized for frontend rendering and
            interaction.
        """
        segments_list = []
        for extraction in result.extractions:
            section_type = self._map_section(extraction.extraction_class)
            if section_type is None:
                continue
            section_label = self._determine_section_label(
                extraction.attributes, section_type
            )
            significance_val = self._extract_clinical_significance(
                extraction.attributes
            )
            intervals = self._get_intervals_from_extraction_dict(
                extraction, extraction.char_interval
            )
            segments_list.extend(
                self._create_segments_for_intervals(
                    section_type,
                    section_label,
                    extraction.extraction_text,
                    intervals,
                    significance_val,
                )
            )
        return segments_list

    def _determine_section_label(
        self,
        attributes: dict[str, str] | None,
        section_type: ReportSectionType,
    ) -> str:
        """Determines the appropriate section label for a segment."""
        if attributes and isinstance(attributes, dict):
            section_label = attributes.get(SECTION_ATTRIBUTE_KEY)
            if section_label:
                return section_label
        return section_type.display_name

    def _extract_clinical_significance(
        self, attributes: dict[str, str] | None
    ) -> str | None:
        """Extracts clinical significance from attributes safely."""
        if not attributes or not isinstance(attributes, dict):
            return None
        try:
            sig_raw = attributes.get("clinical_significance")
            if sig_raw is not None:
                return getattr(sig_raw, "value", str(sig_raw)).lower()
        except (AttributeError, TypeError):
            pass
        return None
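
    # Illustrative attribute handling (hypothetical inputs; SomeEnum stands
    # in for any enum whose .value is a string):
    #   {"clinical_significance": "Minor"}        -> "minor" (strings lowered)
    #   {"clinical_significance": SomeEnum.MINOR} -> the enum's .value, lowered
    #   {} or None                                -> None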

    def _create_segments_for_intervals(
        self,
        section_type: ReportSectionType,
        section_label: str,
        content: str,
        intervals: list[FrontendIntervalDict],
        significance: str | None,
    ) -> list[Segment]:
        """Creates segment objects for the given intervals."""
        if not intervals:
            return [
                Segment(
                    type=section_type,
                    label=section_label,
                    content=content,
                    intervals=[],
                    significance=significance,
                )
            ]
        return [
            Segment(
                type=section_type,
                label=section_label,
                content=content,
                intervals=[interval],
                significance=significance,
            )
            for interval in intervals
        ]

    def _map_section(self, extraction_class: str) -> ReportSectionType | None:
        """Maps an extraction class string to a ReportSectionType enum."""
        extraction_class = extraction_class.lower().strip()
        for section_type in ReportSectionType:
            if section_type.value == extraction_class:
                return section_type
        return None

    def _get_intervals_from_extraction_dict(
        self,
        extraction: langextract.data.Extraction,
        char_interval: langextract.data.CharInterval | dict[str, int] | None = None,
    ) -> list[FrontendIntervalDict]:
        """Extracts character intervals from extraction data.

        Returns a list of interval dictionaries from the extraction's
        char_interval in the format expected by the frontend.

        Args:
            extraction: langextract.data.Extraction object containing
                interval data.
            char_interval: Optional override for character interval data.

        Returns:
            List of dictionaries with startPos and endPos keys.
        """
        interval_list = []
        try:
            char_interval = (
                char_interval
                if char_interval is not None
                else extraction.char_interval
            )
            if char_interval is not None:
                # char_interval may be a langextract.data.CharInterval object
                # or a plain dict override; handle both formats.
                if isinstance(char_interval, dict):
                    start_pos = char_interval.get("start_pos")
                    end_pos = char_interval.get("end_pos")
                else:
                    start_pos = getattr(char_interval, "start_pos", None)
                    end_pos = getattr(char_interval, "end_pos", None)
                start_position, end_position = self._extract_positions(
                    start_pos, end_pos
                )
                if start_position is not None and end_position is not None:
                    interval_list.append(
                        FrontendIntervalDict(
                            startPos=start_position, endPos=end_position
                        )
                    )
        except (AttributeError, TypeError, KeyError):
            # Malformed interval data yields no intervals rather than
            # failing the whole request.
            pass
        return interval_list

    def _extract_positions(self, start_obj, end_obj) -> tuple[int | None, int | None]:
        """Extracts position integers from potentially complex objects.

        Handles possible slice objects or direct integers for start and end
        positions.
        """
        if hasattr(start_obj, "start"):
            start_obj = start_obj.start
        if hasattr(end_obj, "stop"):
            end_obj = end_obj.stop
        try:
            start_position = int(start_obj) if start_obj is not None else None
            end_position = int(end_obj) if end_obj is not None else None
            if start_position is not None and end_position is not None:
                return (start_position, end_position)
        except (TypeError, ValueError):
            # Non-numeric position objects fall through to (None, None).
            pass
        return (None, None)
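
    # Illustrative calls (slice unwrapping is the interesting case):
    #   self._extract_positions(5, 12)                      -> (5, 12)
    #   self._extract_positions(slice(5, 12), slice(5, 12)) -> (5, 12)
    #   self._extract_positions(None, 12)                   -> (None, None)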

    def _extract_char_interval(
        self, extraction: langextract.data.Extraction
    ) -> dict[str, int | None] | None:
        """Extracts character interval information from an extraction."""
        char_interval = extraction.char_interval
        if char_interval is None:
            return None
        return {
            "start_pos": getattr(char_interval, "start_pos", None),
            "end_pos": getattr(char_interval, "end_pos", None),
        }

    def _format_segments_to_text(self, segments: list[Segment]) -> str:
        """Formats segments into a readable text representation.

        Merges segments with the same label into coherent paragraphs while
        preserving the original order of labels as they appear in the
        document.
        """
        grouped = self._group_segments_by_type_and_label(segments)
        formatted_parts: list[str] = []
        self._render_prefix_sections(grouped, segments, formatted_parts)
        self._render_body_sections(grouped, formatted_parts)
        self._render_suffix_sections(grouped, formatted_parts)
        return "\n".join(formatted_parts).rstrip()

    def _group_segments_by_type_and_label(
        self, segments: list[Segment]
    ) -> collections.OrderedDict[tuple[ReportSectionType, str | None], list[str]]:
        """Groups segments by (type, label) preserving insertion order.

        Creates a dictionary keyed by (ReportSectionType, label) tuples that
        maintains the order segments are first encountered. Deduplicates
        content within each group while preserving the original sequence of
        unique content items.

        Args:
            segments: List of Segment objects to group.

        Returns:
            OrderedDict mapping (type, label) tuples to lists of unique
            content strings.
        """
        grouped: collections.OrderedDict[
            tuple[ReportSectionType, str | None], list[str]
        ] = collections.OrderedDict()
        for seg in segments:
            key = (seg.type, seg.label)
            grouped.setdefault(key, [])
            # Strip before the membership test so whitespace variants of the
            # same content do not slip past deduplication.
            content = seg.content.strip()
            if content not in grouped[key]:
                grouped[key].append(content)
        return grouped
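
    # Illustrative grouping (hypothetical segments): repeated BODY content
    # under the same label is deduplicated while key order is preserved:
    #   [BODY/"lungs": "Clear.", BODY/"lungs": "Clear.",
    #    SUFFIX/"suffix": "No acute findings."]
    #   -> OrderedDict({
    #          (ReportSectionType.BODY, "lungs"): ["Clear."],
    #          (ReportSectionType.SUFFIX, "suffix"): ["No acute findings."],
    #      })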

    def _render_prefix_sections(
        self,
        grouped: collections.OrderedDict[
            tuple[ReportSectionType, str | None], list[str]
        ],
        segments: list[Segment],
        formatted_parts: list[str],
    ) -> None:
        """Renders PREFIX sections with appropriate headers."""
        add = formatted_parts.append

        def blank() -> None:
            formatted_parts.append("")

        structured_prefix_exists = any(
            seg.type == ReportSectionType.PREFIX
            and seg.label
            and seg.label.lower() != PREFIX_LABEL
            for seg in segments
        )
        if structured_prefix_exists:
            for (stype, label), contents in grouped.items():
                if stype is not ReportSectionType.PREFIX:
                    continue
                if label and label.lower() == EXAMINATION_LABEL:
                    add(EXAMINATION_HEADER)
                    blank()
                    for c in contents:
                        stripped = self._strip_exam_prefix(c)
                        if stripped:
                            add(stripped)
                            blank()
                else:
                    # Labeled and unlabeled prefix groups render identically:
                    # each non-empty content line followed by a blank line.
                    for c in contents:
                        if c:
                            add(c)
                            blank()
        else:
            plain_prefix = []
            for (stype, _), contents in grouped.items():
                if stype is ReportSectionType.PREFIX:
                    plain_prefix.extend(contents)
            if plain_prefix:
                add("\n\n".join(plain_prefix).rstrip())

    def _render_body_sections(
        self,
        grouped: collections.OrderedDict[
            tuple[ReportSectionType, str | None], list[str]
        ],
        formatted_parts: list[str],
    ) -> None:
        """Renders BODY (FINDINGS) sections."""
        add = formatted_parts.append

        def blank() -> None:
            formatted_parts.append("")

        body_items = [
            (k, v) for k, v in grouped.items() if k[0] is ReportSectionType.BODY
        ]
        if body_items:
            if formatted_parts:
                blank()
            add(FINDINGS_HEADER)
            blank()
            for (_, label), contents in body_items:
                combined = " ".join(contents).strip()
                if combined:
                    add(f"{label}: {combined}")
                    blank()

    def _render_suffix_sections(
        self,
        grouped: collections.OrderedDict[
            tuple[ReportSectionType, str | None], list[str]
        ],
        formatted_parts: list[str],
    ) -> None:
        """Renders SUFFIX (IMPRESSION) sections."""
        add = formatted_parts.append

        def blank() -> None:
            formatted_parts.append("")

        suffix_items = [
            (k, v) for k, v in grouped.items() if k[0] is ReportSectionType.SUFFIX
        ]
        if suffix_items:
            if formatted_parts and formatted_parts[-1].strip():
                blank()
            add(IMPRESSION_HEADER)
            blank()
            suffix_block = "\n".join(
                itertools.chain.from_iterable(v for _, v in suffix_items)
            ).rstrip()
            add(suffix_block)

    def _organize_segments_by_label(self, segments: list[Segment]) -> list[Segment]:
        """Organizes segments into the correct order for presentation.

        Orders segments by section type (prefix → body → suffix), groups
        body segments by label while preserving original appearance order,
        and maintains extraction order for segments with the same label.

        Args:
            segments: List of Segment objects to organize.

        Returns:
            List of segments in proper presentation order.
        """
        prefix_segments = [
            segment for segment in segments if segment.type == ReportSectionType.PREFIX
        ]
        body_segments = [
            segment for segment in segments if segment.type == ReportSectionType.BODY
        ]
        suffix_segments = [
            segment for segment in segments if segment.type == ReportSectionType.SUFFIX
        ]
        body_segments_by_label: dict[str, list[Segment]] = {}
        labels_in_order: list[str] = []
        for segment in body_segments:
            if segment.label:
                if segment.label not in body_segments_by_label:
                    body_segments_by_label[segment.label] = []
                    labels_in_order.append(segment.label)
                body_segments_by_label[segment.label].append(segment)
        organized_segments = []
        organized_segments.extend(prefix_segments)
        for label in labels_in_order:
            organized_segments.extend(body_segments_by_label[label])
        organized_segments.extend(suffix_segments)
        return organized_segments
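
    # Illustrative reordering (hypothetical labels), with body labels kept in
    # first-appearance order and same-label segments kept adjacent:
    #   [BODY/"heart", PREFIX/"examination", BODY/"lungs", BODY/"heart", SUFFIX]
    #   -> [PREFIX/"examination", BODY/"heart", BODY/"heart", BODY/"lungs", SUFFIX]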

    def _strip_exam_prefix(self, text: str) -> str:
        """Removes common examination prefixes from a string."""
        upper = text.upper()
        for prefix in EXAM_PREFIXES:
            if upper.startswith(prefix):
                return text[len(prefix):].lstrip()
        return text.strip()
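

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumes a valid key in GEMINI_API_KEY (the
    # environment variable name is an assumption, not part of this module);
    # the report below is illustrative, not real patient data.
    import os

    _structurer = RadiologyReportStructurer(
        api_key=os.environ.get("GEMINI_API_KEY")
    )
    _demo_report = (
        "EXAMINATION: CT CHEST WITHOUT CONTRAST\n"
        "FINDINGS: Lungs are clear. Heart size is normal.\n"
        "IMPRESSION: No acute abnormality."
    )
    _result = _structurer.predict(_demo_report)
    print(_result["text"])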