| """Flask web application for radiology report structuring using Gemini models. | |
| This module provides a web API that structures radiology reports into | |
| semantic sections using LangExtract and Google's Gemini language models. | |
| The application supports caching, multiple model configurations, and | |
| provides both a web interface and REST API endpoints. | |
| Typical usage example: | |
| # Set environment variables | |
| export KEY=your_gemini_api_key_here | |
| export MODEL_ID=gemini-2.5-flash | |
| # Run the application | |
| python app.py | |
| """ | |
import hashlib
import json
import logging
import os
import shutil
import tempfile
import time

from flask import Flask, jsonify, render_template, request
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address

from cache_manager import CacheManager
from sanitize import preprocess_report
from social_sharing import SocialSharingConfig
from structure_report import RadiologyReportStructurer, ResponseDict
# Configuration constants
MAX_INPUT_LENGTH = 3000

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)
class Model:
    """Manages RadiologyReportStructurer instances for different Gemini model IDs.

    This class handles initialization, caching, and coordination
    of structurer instances for various model configurations, ensuring
    efficient resource usage and consistent API key management.
    """

    def __init__(self):
        """Initializes the Model manager with a default structurer.

        Sets up the Gemini API key from environment variables
        and creates a default structurer instance for the configured model.

        Raises:
            ValueError: If the KEY environment variable is not set.
        """
        self.gemini_api_key = os.environ.get("KEY")
        if not self.gemini_api_key:
            logger.error("KEY environment variable not set.")
            raise ValueError("KEY environment variable not set.")

        self._structurers: dict[str, RadiologyReportStructurer] = {}
        default_model_id = os.environ.get("MODEL_ID", "gemini-2.5-flash")
        self._structurers[default_model_id] = RadiologyReportStructurer(
            api_key=self.gemini_api_key,
            model_id=default_model_id,
        )
        logger.info(
            f"RadExtract ready [Worker {os.getpid()}] with model: {default_model_id}"
        )
    def _get_structurer(self, model_id: str) -> RadiologyReportStructurer:
        """Returns a cached or newly created structurer for the given model ID.

        Args:
            model_id: Identifier for the specific model configuration.

        Returns:
            RadiologyReportStructurer instance for the specified model.
        """
        if model_id not in self._structurers:
            logger.info(f"Creating structurer for model: {model_id}")
            self._structurers[model_id] = RadiologyReportStructurer(
                api_key=self.gemini_api_key,
                model_id=model_id,
            )
        return self._structurers[model_id]
    def predict(self, data: str, model_id: str) -> ResponseDict:
        """Processes a prediction request using the specified model.

        Args:
            data: Input text data to be processed.
            model_id: Identifier for the model to use for processing.

        Returns:
            Dictionary containing the structured prediction results.
        """
        logger.info(f"Processing prediction with model: {model_id}")
        structurer = self._get_structurer(model_id)
        result = structurer.predict(data)
        logger.info(f"Result preview: {str(result)[:500]}...")
        return result
model = Model()
# Copy the prebuilt cache to a writable location if it exists
def setup_cache():
    """Sets up the cache directory and copies the prebuilt cache file.

    Creates a writable cache directory under the system temp directory and
    copies the bundled sample cache into it on first startup, if the cache
    has not already been copied.

    Returns:
        Path to the configured cache directory.
    """
    cache_dir = tempfile.gettempdir() + "/cache"
    os.makedirs(cache_dir, exist_ok=True)

    source_cache = "cache/sample_cache.json"
    target_cache = os.path.join(cache_dir, "sample_cache.json")
    if os.path.exists(source_cache) and not os.path.exists(target_cache):
        shutil.copy2(source_cache, target_cache)
        logger.info(f"Initialized cache with {os.path.getsize(target_cache)} bytes")

    return cache_dir
cache_dir = setup_cache()
cache_manager = CacheManager(cache_dir=cache_dir)
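# The CacheManager interface assumed by the handlers below, inferred from its
# call sites in this file (the real definition lives in cache_manager.py):
#
#   class CacheManager:
#       def __init__(self, cache_dir: str): ...
#       def get_cached_result(self, text: str, sample_id: str | None) -> dict | None: ...
#       def cache_result(self, text: str, result: dict, sample_id: str) -> None: ...
#       def get_cache_stats(self) -> dict: ...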
app = Flask(
    __name__,
    static_url_path="/static",
    static_folder="static",
    template_folder="templates",
)
# Initialize rate limiter
limiter = Limiter(
    get_remote_address,
    app=app,
    default_limits=[
        os.environ.get("RATE_LIMIT_DAY", "200 per day"),
        os.environ.get("RATE_LIMIT_HOUR", "50 per hour"),
    ],
    storage_uri="memory://",
)
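# The default_limits above apply to every route; an endpoint-specific cap can
# be layered on top with flask-limiter's limit decorator. Illustrative only;
# this route is not part of the application:
#
#   @app.route("/health")
#   @limiter.limit("5 per minute")
#   def health():
#       return jsonify({"status": "ok"})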
# NOTE: the route paths used in the decorators below are assumed.
@app.route("/")
def index():
    """Renders the main application interface.

    Returns:
        Rendered HTML template for the application index page.
    """
    # Get social sharing context
    social_context = SocialSharingConfig.get_sharing_context(request.url_root)
    return render_template("index.html", **social_context)
@app.route("/cache-stats")  # Path assumed
def cache_stats():
    """Returns cache performance statistics.

    Returns:
        JSON response containing cache usage and performance statistics.
    """
    return jsonify(cache_manager.get_cache_stats())
@app.route("/predict", methods=["POST"])  # Path assumed
def predict():
    """Processes radiology report text and returns structured results.

    Accepts raw text via the POST request body, with optional headers for
    caching, sample identification, and model selection. Supports both
    cached and real-time processing modes.

    Returns:
        JSON response containing structured report segments, annotations,
        and formatted text. Includes cache status when applicable. Returns
        a 500 status if processing fails due to invalid input or model errors.
    """
    start_time = time.time()
    try:
        data = request.get_data(as_text=True)

        # Validate input to ensure it meets API requirements
        if not data or not data.strip():
            return (
                jsonify(
                    {
                        "error": "Empty input",
                        "message": "Input text is required",
                        "max_length": MAX_INPUT_LENGTH,
                    }
                ),
                400,
            )

        if len(data) > MAX_INPUT_LENGTH:
            return (
                jsonify(
                    {
                        "error": "Input too long",
                        "message": f"Input length ({len(data)} characters) exceeds maximum allowed length of {MAX_INPUT_LENGTH} characters",
                        "max_length": MAX_INPUT_LENGTH,
                    }
                ),
                400,
            )

        use_cache = request.headers.get("X-Use-Cache", "true").lower() == "true"
        sample_id = request.headers.get("X-Sample-ID")
        model_id = request.headers.get(
            "X-Model-ID", os.environ.get("MODEL_ID", "gemini-2.5-flash")
        )

        processed_data = preprocess_report(data)

        if use_cache:
            cached_result = cache_manager.get_cached_result(processed_data, sample_id)
            if cached_result:
                req_id = hashlib.md5(
                    f"{request.remote_addr}{int(time.time()/3600)}".encode()
                ).hexdigest()[:8]
                logger.info(
                    f"🟢 CACHE HIT [Req {req_id}] [Worker {os.getpid()}] - Returning cached result (no API call)"
                )
                return jsonify({"from_cache": True, **cached_result})

        try:
            req_id = hashlib.md5(
                f"{request.remote_addr}{int(time.time()/3600)}".encode()
            ).hexdigest()[:8]
            logger.info(
                f"🔴 API CALL [Req {req_id}] [Worker {os.getpid()}] - Processing with Gemini model: {model_id}"
            )
            result = model.predict(processed_data, model_id=model_id)

            # Cache if it's a predefined sample
            if use_cache and sample_id:
                cache_manager.cache_result(processed_data, result, sample_id)

            result["sanitized_input"] = processed_data
            return jsonify(result)
        except TypeError as te:
            error_msg = str(te)
            logger.error(f"TypeError in prediction: {error_msg}", exc_info=True)
            return (
                jsonify({"error": "Processing error. Please try a different input."}),
                500,
            )
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}", exc_info=True)
        return jsonify({"error": str(e)}), 500
@app.errorhandler(429)
def ratelimit_handler(e):
    """Handles rate limit exceeded errors."""
    return (
        jsonify(
            {
                "error": "Rate limit exceeded. Please try again later.",
                "message": str(e.description),
            }
        ),
        429,
    )
| if __name__ == "__main__": | |
| logger.info("Starting development server") | |
| app.run(host="0.0.0.0", port=7870, debug=True) | |
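# The [Worker <pid>] tags in the log messages above suggest the app is served
# by a multi-process WSGI server in production; a typical invocation (assumed,
# not taken from this repository) would be:
#
#   gunicorn --workers 2 --bind 0.0.0.0:7870 app:app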