# -*- coding: utf-8 -*-
"""Flask service that runs SegFormer semantic segmentation on uploaded images.

At import time the module configures logging, creates the upload/result
folders, loads the OmegaConf YAML configuration and the model checkpoint.
It then exposes three HTTP endpoints:

* ``POST /api/analyze``      - decode a base64 image, run inference, save overlay.
* ``GET  /api/toggle-image`` - report whether an overlay exists for a filename.
* ``GET  /Result/<filename>`` - serve a saved overlay image.
"""
import base64
import binascii
import io  # noqa: F401  (kept: not used in this file's visible code)
import logging
import os
import sys
import traceback
from logging.handlers import RotatingFileHandler

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import yaml  # noqa: F401  (kept: not used in this file's visible code)
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
from omegaconf import OmegaConf
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
from werkzeug.exceptions import NotFound
from werkzeug.utils import secure_filename

# --- Configuration ---
# Directory containing this script; other paths are derived from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# >>> Point this to your actual config file <<<
CONFIG_PATH = os.path.join(BASE_DIR, "config/config.yaml")
# >>> Point this to your actual checkpoint file <<<
# NOTE: unlike the other paths this one is relative, so it is resolved against
# the process working directory, not BASE_DIR.
CHECKPOINT_PATH = "ckpt_000-vloss_0.4685_vf1_0.6469.ckpt"
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'uploads')
RESULT_FOLDER = os.path.join(BASE_DIR, 'results')
LOG_FILE_PATH = os.path.join(BASE_DIR, 'flask_app.log')
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp', 'tif', 'tiff'}

# --- Logging Setup ---
# Clear existing root handlers so a module reload does not duplicate output.
logging.getLogger().handlers.clear()
log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s')
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)
# Rotating file handler: 5 MiB per file, 3 backups.
file_handler = RotatingFileHandler(LOG_FILE_PATH, maxBytes=5*1024*1024, backupCount=3)
file_handler.setFormatter(log_formatter)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(console_handler)
logger.addHandler(file_handler)

# --- Ensure upload and result directories exist ---
try:
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    os.makedirs(RESULT_FOLDER, exist_ok=True)
    logger.info(f"Ensured directories exist: {UPLOAD_FOLDER}, {RESULT_FOLDER}")
except OSError as e:
    logger.error(f"Error creating directories: {e}")
    sys.exit(1)  # Cannot run without the essential folders

# --- Load Config ---
config = None
try:
    config = OmegaConf.load(CONFIG_PATH)
    logger.info(f"Configuration loaded successfully from: {CONFIG_PATH}")
    # Log some key values to confirm loading
    logger.info(f"Config check: num_classes={config.data.num_classes}, model_name={config.training.model_name}")
except FileNotFoundError:
    logger.error(f"Configuration file not found: {CONFIG_PATH}")
    sys.exit(1)
except Exception as e:  # Any other loading/parsing failure is fatal at startup
    logger.error(f"Error loading or parsing configuration file '{CONFIG_PATH}': {e}")
    logger.error(traceback.format_exc())
    sys.exit(1)

# --- Determine device ---
# Defined before the functions below because load_trained_model reads this
# module-level name when it maps the checkpoint onto a device.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")


# --- Model Definition ---
class InferenceModel(torch.nn.Module):
    """Wraps a Hugging Face SegFormer and upsamples its logits to input size."""

    def __init__(self, model_config):
        """Build the SegFormer from ``training.model_name`` / ``data.num_classes``.

        Raises:
            AttributeError: if a required config key is missing (re-raised).
            Exception: any error from ``from_pretrained`` (re-raised).
        """
        super().__init__()
        try:
            model_name = model_config.training.model_name
            num_classes = model_config.data.num_classes
            logger.info(f"Initializing SegformerForSemanticSegmentation with model='{model_name}', num_labels={num_classes}")
            self.model = SegformerForSemanticSegmentation.from_pretrained(
                model_name,
                num_labels=num_classes,
                ignore_mismatched_sizes=True,  # Head size may differ from the pretrained one
            )
            logger.info("Segformer model part initialized.")
        except AttributeError as ae:
            logger.error(f"Config error during model init: Missing key? {ae}")
            logger.error(f"Check if 'training.model_name' and 'data.num_classes' exist in {CONFIG_PATH}")
            raise
        except Exception as e:
            logger.error(f"Error initializing Segformer model from Hugging Face: {e}")
            logger.error(traceback.format_exc())
            raise

    def forward(self, x):
        """Return logits bilinearly resized to the last two dims of ``x``."""
        outputs = self.model(pixel_values=x, return_dict=True)
        logits = F.interpolate(
            outputs.logits,
            size=x.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        return logits


# --- Utility Functions ---
def num_to_rgb(num_arr, color_map_dict):
    """Convert a label mask to a float32 RGB image in [0, 1].

    Args:
        num_arr: array of labels; singleton dimensions are squeezed and the
            result must be 2-D.
        color_map_dict: plain ``dict`` mapping int label -> 3-element
            list/tuple colour (0-255 values).

    Returns:
        ``(H, W, 3)`` float32 array, or ``None`` when the mask is not 2-D after
        squeezing. An invalid ``color_map_dict`` yields an all-black image.
    """
    single_layer = np.squeeze(num_arr)
    if single_layer.ndim != 2:
        logger.error(f"num_to_rgb expected a 2D label mask after squeeze, got shape {single_layer.shape}.")
        return None
    # Size the output from the squeezed mask so boolean indexing below lines up
    # even when num_arr carried extra singleton dimensions.
    output = np.zeros(single_layer.shape + (3,), dtype=np.uint8)

    if not isinstance(color_map_dict, dict):
        logger.error(f"Invalid color_map provided to num_to_rgb: {type(color_map_dict)}. Expected dict.")
        return np.float32(output) / 255.0

    for k in np.unique(single_layer):
        label_key = int(k)  # numpy scalar -> plain int for dict lookup
        if label_key in color_map_dict:
            color = color_map_dict[label_key]
            if isinstance(color, (list, tuple)) and len(color) == 3:
                output[single_layer == k] = color
            else:
                logger.warning(f"Invalid color format for label {label_key} in color map: {color}. Skipping.")
        elif label_key != 0:  # Label 0 is allowed to be absent without a warning
            logger.warning(f"Label Key {label_key} found in mask but not in provided color map.")
        # Unmapped labels stay black (the np.zeros default).

    return np.float32(output) / 255.0


def denormalize(tensor, mean, std):
    """Undo per-channel normalisation and return a single HWC numpy image.

    Args:
        tensor: ``(1, C, H, W)`` tensor, or ``(C, H, W)`` (a batch dim is added).
        mean, std: lists/tuples with one entry per channel.

    Returns:
        ``(H, W, C)`` numpy array clamped to [0, 1], or ``None`` on invalid
        input (wrong types, dims, channel count, or batch size other than 1).
    """
    if not isinstance(mean, (list, tuple)) or not isinstance(std, (list, tuple)):
        logger.error(f"Mean ({type(mean)}) or std ({type(std)}) are not lists/tuples in denormalize.")
        return None

    if tensor.dim() == 3:
        # Recoverable: treat CHW as a batch of one.
        logger.warning("Denormalize received 3D tensor, adding batch dimension.")
        tensor = tensor.unsqueeze(0)
    elif tensor.dim() != 4:
        logger.error(f"Unexpected tensor dimension {tensor.dim()} in denormalize. Expected 4 (BCHW).")
        return None

    if tensor.shape[0] != 1:
        # The squeeze/permute below only yields an HWC image for a single item.
        logger.error(f"denormalize expects a batch of exactly one image, got batch size {tensor.shape[0]}.")
        return None

    num_channels = tensor.shape[1]
    if len(mean) != num_channels or len(std) != num_channels:
        logger.error(f"Mean/std length ({len(mean)}/{len(std)}) mismatch with tensor channels ({num_channels})")
        return None

    tensor = tensor.clone().cpu()  # Work on a CPU copy; never mutate the caller's tensor
    for c in range(num_channels):
        tensor[:, c, :, :] = tensor[:, c, :, :] * std[c] + mean[c]

    denormalized_img_tensor = torch.clamp(tensor.squeeze(0), 0, 1).permute(1, 2, 0)
    return denormalized_img_tensor.numpy()


# --- Load Model ---
def _reconcile_model_prefix(state_dict, target_state_dict):
    """Return ``state_dict`` with the 'model.' key prefix added/removed to match the target.

    The decision is made from the first key of each (ordered) mapping, so it is
    deterministic. Keys whose adjusted name is not a target key are kept
    unchanged and reported in a warning.
    """
    target_keys = set(target_state_dict.keys())
    first_loaded_key = next(iter(state_dict), None)
    first_target_key = next(iter(target_state_dict), None)
    if not first_loaded_key or not first_target_key:
        logger.info("State dict keys seem to have correct prefix structure (or other mismatch). Using as is.")
        return state_dict

    loaded_has_prefix = first_loaded_key.startswith('model.')
    target_has_prefix = first_target_key.startswith('model.')

    if not loaded_has_prefix and target_has_prefix:
        logger.warning("Checkpoint keys missing 'model.' prefix. Attempting to add it.")
        corrected = {}
        kept_unprefixed = []
        for k, v in state_dict.items():
            new_key = f"model.{k}"
            if new_key in target_keys:
                corrected[new_key] = v
            else:
                kept_unprefixed.append(k)
                corrected[k] = v
        if kept_unprefixed:
            logger.warning(f"Keys kept without prefix (target doesn't expect): {kept_unprefixed}")
        logger.info("Finished attempting prefix addition.")
        return corrected

    if loaded_has_prefix and not target_has_prefix:
        logger.warning("Checkpoint keys HAVE 'model.' prefix, but target model DOES NOT. Attempting to remove it.")
        corrected = {}
        kept_prefixed = []
        for k, v in state_dict.items():
            if not k.startswith('model.'):
                corrected[k] = v
                continue
            new_key = k[len('model.'):]
            if new_key in target_keys:
                corrected[new_key] = v
            else:
                kept_prefixed.append(k)
                corrected[k] = v
        if kept_prefixed:
            logger.warning(f"Keys kept with prefix (target doesn't expect stripped): {kept_prefixed}")
        logger.info("Finished attempting prefix removal.")
        return corrected

    logger.info("State dict keys seem to have correct prefix structure (or other mismatch). Using as is.")
    return state_dict


def load_trained_model(checkpoint_path, model_config):
    """Build an ``InferenceModel`` and load weights from ``checkpoint_path``.

    The checkpoint may be a dict with a ``"state_dict"`` entry, a bare state
    dict, or another object that is passed straight to ``load_state_dict``.
    Key prefixes are reconciled, then weights are loaded with ``strict=False``
    and the result is checked: any missing key, or any unexpected key not
    ending in ``num_batches_tracked``, is treated as a failure.

    Returns:
        The model in eval mode, or ``None`` on any failure (errors are logged,
        not raised).
    """
    try:
        model_instance = InferenceModel(model_config)
        logger.info(f"Attempting to load checkpoint from: {checkpoint_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at specified path: {checkpoint_path}")

        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
        logger.info(f"Checkpoint loaded into memory. Type: {type(checkpoint)}")

        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            logger.info("Using 'state_dict' key from checkpoint.")
        elif isinstance(checkpoint, dict):
            state_dict = checkpoint
            logger.info("Using checkpoint dictionary directly as state_dict (no 'state_dict' key found).")
        else:
            logger.warning(f"Checkpoint is not a dictionary. Attempting to load directly into model (less common). Type was: {type(checkpoint)}")
            try:
                model_instance.load_state_dict(checkpoint)
                logger.info("Loaded state_dict directly from checkpoint object.")
                model_instance.eval()
                return model_instance
            except Exception as e:
                logger.error(f"Failed to load state_dict directly from checkpoint object: {e}")
                return None

        if not state_dict:
            logger.warning("Loaded state_dict is empty!")
            return None

        corrected_state_dict = _reconcile_model_prefix(state_dict, model_instance.state_dict())

        # strict=False so that mismatches are reported below instead of raising.
        logger.info("Attempting to load state_dict with strict=False for checking...")
        missing_keys, unexpected_keys = model_instance.load_state_dict(corrected_state_dict, strict=False)

        final_msg = []
        is_load_successful = True
        if missing_keys:
            final_msg.append(f"MISSING keys in checkpoint: {missing_keys}")
            logger.error("CRITICAL FAILURE: Model is missing required keys.")
            is_load_successful = False
        if unexpected_keys:
            final_msg.append(f"UNEXPECTED keys in checkpoint (exist in file but not in model): {unexpected_keys}")
            acceptable_unexpected = [k for k in unexpected_keys if k.endswith('num_batches_tracked')]
            unacceptable_unexpected = [k for k in unexpected_keys if not k.endswith('num_batches_tracked')]
            if unacceptable_unexpected:
                logger.error(f"CRITICAL FAILURE: Model received unacceptable unexpected keys: {unacceptable_unexpected}")
                is_load_successful = False
            elif acceptable_unexpected:
                logger.warning(f"Ignoring acceptable unexpected keys: {acceptable_unexpected}")

        if not is_load_successful:
            logger.error(f"State dict loading failed. Issues: {'; '.join(final_msg)}")
            return None

        logger.info(f"State dictionary loaded successfully. Issues (if any): {final_msg if final_msg else 'None'}")
        model_instance.eval()
        logger.info(f"Model loading process complete for {checkpoint_path}")
        return model_instance

    except FileNotFoundError as fnf_error:
        logger.error(f"{fnf_error}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error during model loading: {e}")
        logger.error(traceback.format_exc())
        return None


# --- Load Model Globally ---
model = load_trained_model(CHECKPOINT_PATH, config)
if model is None:
    logger.critical("CRITICAL: Failed to load model. Application cannot continue.")
    sys.exit(1)
model.to(device)


# --- Inference Pipeline ---
def run_inference_on_bytes(image_bytes, inference_model, model_config, device):
    """Run segmentation on encoded image bytes.

    Args:
        image_bytes: encoded image file contents (decoded with ``cv2.imdecode``).
        inference_model: callable returning ``(B, num_classes, H, W)`` logits.
        model_config: OmegaConf config providing ``data.image_size``,
            ``data.mean``, ``data.std``, ``data.num_classes``, ``id2color`` and
            ``training.model_name``.
        device: device the input tensor is moved to.

    Returns:
        ``(denorm_img, color_mask, overlay)`` - three HWC float32 RGB arrays in
        [0, 1] - or ``(None, None, None)`` on any failure (logged, not raised).
    """
    try:
        nparr = np.frombuffer(image_bytes, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            logger.error("Failed cv2.imdecode.")
            return None, None, None
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        logger.debug("Image decoded and converted to RGB.")

        # --- Preprocessing: verify the required config fields are present ---
        for key in ('image_size', 'mean', 'std', 'num_classes'):
            if OmegaConf.select(model_config, f'data.{key}', default=None) is None:
                logger.error(f"Config missing required data field: data.{key}")
                return None, None, None
        for key in ('id2color', 'training.model_name'):
            if OmegaConf.select(model_config, key, default=None) is None:
                logger.error(f"Config missing required field: {key}")
                return None, None, None

        try:
            # Convert OmegaConf containers to plain Python types;
            # resolve=True expands any ${...} interpolations.
            img_size = tuple(OmegaConf.to_container(model_config.data.image_size, resolve=True))
            mean = list(OmegaConf.to_container(model_config.data.mean, resolve=True))
            std = list(OmegaConf.to_container(model_config.data.std, resolve=True))
            id2color_map = {
                int(k): v
                for k, v in OmegaConf.to_container(model_config.id2color, resolve=True).items()
            }
            num_classes = int(model_config.data.num_classes)
            logger.debug(f"Converted config values: size={img_size}, mean={mean}, std={std}, id2color keys={list(id2color_map.keys())}, num_classes={num_classes}")
            if not isinstance(mean, list) or not isinstance(std, list) or not isinstance(id2color_map, dict):
                raise TypeError("Config values did not convert to list/dict.")
            if len(mean) != 3 or len(std) != 3:
                raise ValueError("Incorrect mean/std length. Expected 3.")
            if len(img_size) != 2:
                raise ValueError("Incorrect image_size length. Expected 2 (H, W).")
        except Exception as e:
            logger.error(f"Error processing/converting configuration values: {e}")
            logger.error(traceback.format_exc())
            return None, None, None

        transform = transforms.Compose([
            transforms.ToTensor(),  # HWC uint8 numpy -> CHW float tensor in [0, 1]
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize(mean=mean, std=std),
        ])
        logger.debug(f"Image transform applied for size {img_size}.")
        input_tensor = transform(img_rgb).unsqueeze(0).to(device)  # Add batch dim (B=1)
        logger.debug(f"Input tensor created with shape: {input_tensor.shape}")

        # --- Run Prediction ---
        with torch.no_grad():
            logits = inference_model(input_tensor)
        logger.debug(f"Logits received with shape: {logits.shape}")

        if logits.dim() != 4 or logits.shape[1] != num_classes:
            logger.error(f"Unexpected final logits shape or class number: {logits.shape}. Expected B x {num_classes} x H x W.")
            return None, None, None

        # Argmax over the class dim, drop the batch dim, move to CPU.
        pred_mask = logits.argmax(1).squeeze(0).cpu().numpy().astype(np.uint8)
        logger.debug(f"Prediction mask generated with shape: {pred_mask.shape}")

        # --- Post-processing ---
        color_mask = num_to_rgb(pred_mask, id2color_map)
        if color_mask is None:
            logger.error("num_to_rgb failed.")
            return None, None, None
        logger.debug("Color mask generated.")

        # The overlay is built on the resized model input, not the original upload.
        denorm_img = denormalize(input_tensor, mean, std)
        if denorm_img is None:
            logger.error("denormalize failed.")
            return None, None, None
        logger.debug("Input tensor denormalized for overlay.")

        # --- Create Overlay ---
        if denorm_img.shape[:2] != color_mask.shape[:2]:
            logger.warning(f"Denorm img shape {denorm_img.shape[:2]} != Color mask shape {color_mask.shape[:2]}. Resizing color mask using INTER_NEAREST.")
            # Nearest-neighbour keeps label colours crisp (no blended colours).
            color_mask = cv2.resize(
                color_mask,
                (denorm_img.shape[1], denorm_img.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

        # overlay = 0.7 * image + 0.3 * mask
        overlay = cv2.addWeighted(denorm_img, 0.7, color_mask, 0.3, 0)
        logger.debug("Overlay created using cv2.addWeighted.")
        return denorm_img, color_mask, overlay

    except Exception as e:
        logger.error(f"Exception during inference pipeline for image: {e}")
        logger.error(traceback.format_exc())
        return None, None, None


# --- Flask App ---
app = Flask(__name__)
# Default CORS(app) allows all origins on all routes. To restrict, pass e.g.
# resources={r"/api/*": {"origins": "*"}, r"/Result/*": {"origins": "*"}}
CORS(app)
logger.info("Flask app created and CORS enabled.")


def allowed_file(filename):
    """Return True if ``filename`` has an extension listed in ALLOWED_EXTENSIONS."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def _overlay_filename(safe_filename):
    """Return the overlay filename for a sanitized upload name.

    Single source of truth shared by /api/analyze (writer) and
    /api/toggle-image (reader) so the two cannot drift apart.
    """
    name_part, ext = os.path.splitext(safe_filename)
    return f"analyzed_{name_part}{ext}"


# --- API Endpoints ---
@app.route('/api/analyze', methods=['POST'])
def analyze_image():
    """Receive a base64 image as JSON, run inference, save original and overlay.

    Expects a JSON object with string fields ``image`` (base64, optionally with
    a ``data:...,`` header) and ``filename``. Responds with JSON: 200 on
    success, 400 for invalid input, 500 for server-side failures.
    """
    endpoint_log_prefix = "[POST /api/analyze]"
    logger.info(f"{endpoint_log_prefix} Received request.")

    # --- Basic Checks ---
    if model is None:
        logger.error(f"{endpoint_log_prefix} Model not loaded.")
        return jsonify({"success": False, "message": "Model not loaded"}), 500
    if not request.is_json:
        logger.warning(f"{endpoint_log_prefix} Not JSON.")
        return jsonify({"success": False, "message": "Request must be JSON"}), 400

    # silent=True: malformed JSON yields None and gets the JSON 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data or 'filename' not in data:
        # Log only the keys: the body may contain a very large base64 payload.
        received = list(data) if isinstance(data, dict) else type(data).__name__
        logger.warning(f"{endpoint_log_prefix} Missing image/filename in JSON body. Keys received: {received}")
        return jsonify({"success": False, "message": "Missing 'image' (base64) or 'filename' in JSON body"}), 400

    base64_image_data = data['image']
    original_filename = data['filename']
    if not isinstance(base64_image_data, str) or not isinstance(original_filename, str):
        logger.warning(f"{endpoint_log_prefix} 'image'/'filename' are not strings.")
        return jsonify({"success": False, "message": "'image' and 'filename' must be strings"}), 400

    logger.info(f"{endpoint_log_prefix} Original filename from request: '{original_filename}'")
    safe_original_filename = secure_filename(original_filename)
    if not safe_original_filename or not allowed_file(safe_original_filename):
        logger.warning(f"{endpoint_log_prefix} Invalid/disallowed filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
        return jsonify({"success": False, "message": "Invalid or disallowed filename/extension"}), 400
    logger.info(f"{endpoint_log_prefix} Sanitized filename for saving/processing: '{safe_original_filename}'")

    # --- Decode Base64 ---
    try:
        # Strip an optional "data:<mime>;base64," header.
        if ',' in base64_image_data:
            encoded = base64_image_data.split(',', 1)[1]
        else:
            encoded = base64_image_data
        image_bytes = base64.b64decode(encoded)
    except (binascii.Error, ValueError) as e:
        logger.error(f"{endpoint_log_prefix} Invalid Base64 data received: {e}")
        return jsonify({"success": False, "message": "Invalid Base64 image data received"}), 400
    logger.info(f"{endpoint_log_prefix} Base64 image decoded ({len(image_bytes)} bytes).")

    try:
        # --- Save Original Image ---
        original_path = os.path.join(UPLOAD_FOLDER, safe_original_filename)
        try:
            with open(original_path, "wb") as f:
                f.write(image_bytes)
            logger.info(f"{endpoint_log_prefix} Original image saved to: '{original_path}'")
        except Exception as e:
            logger.error(f"{endpoint_log_prefix} Failed to save original image to '{original_path}': {e}")
            return jsonify({"success": False, "message": "Failed to save uploaded image on server"}), 500

        # --- Run Inference ---
        logger.info(f"{endpoint_log_prefix} Starting inference for '{safe_original_filename}'...")
        denorm_img, color_mask, overlay = run_inference_on_bytes(image_bytes, model, config, device)
        if overlay is None:
            logger.error(f"{endpoint_log_prefix} Inference pipeline returned None for '{safe_original_filename}'.")
            return jsonify({"success": False, "message": "Inference process failed on server"}), 500
        logger.info(f"{endpoint_log_prefix} Inference completed successfully for '{safe_original_filename}'.")

        # --- Save Overlay Image ---
        overlay_filename = _overlay_filename(safe_original_filename)
        overlay_path = os.path.join(RESULT_FOLDER, overlay_filename)
        logger.info(f"{endpoint_log_prefix} Determined overlay filename: '{overlay_filename}' -> path: '{overlay_path}'")
        try:
            # float32 RGB [0, 1] -> uint8 BGR [0, 255], the layout cv2.imwrite expects.
            overlay_to_save_uint8 = (overlay * 255).astype(np.uint8)
            overlay_to_save_bgr = cv2.cvtColor(overlay_to_save_uint8, cv2.COLOR_RGB2BGR)
            save_success = cv2.imwrite(overlay_path, overlay_to_save_bgr)
            if not save_success:
                raise IOError(f"cv2.imwrite failed to save the overlay image to {overlay_path}")
            logger.info(f"{endpoint_log_prefix} Overlay image saved successfully to: '{overlay_path}'")
        except Exception as e:
            logger.error(f"{endpoint_log_prefix} Failed to convert or save overlay image to '{overlay_path}': {e}")
            logger.error(traceback.format_exc())
            return jsonify({"success": False, "message": "Failed to save analysis result image"}), 500

        # --- Success Response ---
        logger.info(f"{endpoint_log_prefix} Analysis successful for '{safe_original_filename}'. Returning success.")
        return jsonify({
            "success": True,
            "message": "Analysis complete",
            "paths": {
                "original": os.path.relpath(original_path, BASE_DIR),
                "overlay": os.path.relpath(overlay_path, BASE_DIR),
            },
            "overlay_filename": overlay_filename,  # The exact filename saved
        }), 200

    except Exception as e:
        logger.error(f"{endpoint_log_prefix} Unexpected error during analysis request processing: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"success": False, "message": "Internal server error during analysis processing"}), 500


@app.route('/api/toggle-image', methods=['GET'])
def get_analysis_path():
    """Report whether an overlay exists for the ``filename`` query parameter.

    Returns 200 with ``{"filepath": <overlay filename>}`` when found, 400 for a
    missing/invalid filename, 404 when no overlay file exists.
    """
    endpoint_log_prefix = "[GET /api/toggle-image]"
    logger.info(f"{endpoint_log_prefix} Received request.")
    logger.info(f"{endpoint_log_prefix} Full request URL: {request.url}")
    logger.info(f"{endpoint_log_prefix} Request Query Args: {request.args}")

    original_filename = request.args.get('filename')
    if not original_filename:
        logger.warning(f"{endpoint_log_prefix} Missing 'filename' query parameter.")
        return jsonify({"message": "Missing 'filename' query parameter"}), 400
    logger.info(f"{endpoint_log_prefix} Original filename received from query: '{original_filename}'")

    safe_original_filename = secure_filename(original_filename)
    if not safe_original_filename:
        logger.warning(f"{endpoint_log_prefix} Invalid filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
        return jsonify({"message": "Invalid filename format"}), 400
    logger.info(f"{endpoint_log_prefix} Sanitized filename for lookup: '{safe_original_filename}'")

    expected_overlay_filename = _overlay_filename(safe_original_filename)
    expected_overlay_path = os.path.join(RESULT_FOLDER, expected_overlay_filename)
    logger.info(f"{endpoint_log_prefix} Expecting overlay file at: '{expected_overlay_path}'")

    if os.path.exists(expected_overlay_path):
        logger.info(f"{endpoint_log_prefix} Found analysis result file: '{expected_overlay_filename}'")
        # Only the filename is returned; the client builds the /Result/ URL.
        return jsonify({"filepath": expected_overlay_filename}), 200

    logger.warning(f"{endpoint_log_prefix} Analysis result file NOT FOUND at checked path: '{expected_overlay_path}'")
    return jsonify({"message": f"Analysis result not found for '{original_filename}'"}), 404


@app.route('/Result/<filename>')
def serve_result_image(filename):
    """Serve an image from RESULT_FOLDER.

    The name is passed through ``secure_filename`` first; the sanitized name is
    what gets served. Returns 404 JSON when the file does not exist, 500 JSON
    on other errors.
    """
    endpoint_log_prefix = "[GET /Result]"
    safe_filename = secure_filename(filename)
    if safe_filename != filename:
        logger.warning(f"{endpoint_log_prefix} Requested filename '{filename}' was sanitized to '{safe_filename}'. Serving sanitized version.")
    if not safe_filename:
        return jsonify({"message": "Requested analysis image not found"}), 404

    logger.info(f"{endpoint_log_prefix} Attempting to serve file: '{safe_filename}' from directory: '{RESULT_FOLDER}'")
    try:
        return send_from_directory(RESULT_FOLDER, safe_filename, as_attachment=False)
    except (NotFound, FileNotFoundError):
        # send_from_directory signals a missing file with werkzeug's NotFound.
        logger.error(f"{endpoint_log_prefix} Requested file not found in result folder: '{safe_filename}'")
        return jsonify({"message": "Requested analysis image not found"}), 404
    except Exception as e:
        logger.error(f"{endpoint_log_prefix} Error serving file '{safe_filename}': {e}")
        logger.error(traceback.format_exc())
        return jsonify({"message": "Error serving analysis image"}), 500


# --- Main Execution ---
if __name__ == '__main__':
    if model is not None:
        logger.info("Model loaded successfully. Starting Flask development server...")
        # The server binds to all interfaces, so the interactive debugger must
        # not be on by default (it allows code execution from the network).
        # Opt in for local development with FLASK_DEBUG=1.
        debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
        app.run(host='0.0.0.0', port=7860, debug=debug_mode)
    else:
        logger.critical("APPLICATION FAILED TO START: MODEL COULD NOT BE LOADED.")
        sys.exit(1)