File size: 31,983 Bytes
8d7600a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
# -*- coding: utf-8 -*-
import os
import base64
import io
import logging # For better logging
# Import specific handlers and formatter
from logging.handlers import RotatingFileHandler
import traceback # For detailed exception logging
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS # To handle Cross-Origin requests from your frontend
import torch
import cv2
import numpy as np
import yaml
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
from omegaconf import OmegaConf # Import OmegaConf itself
import torch.nn.functional as F
from werkzeug.utils import secure_filename # For safer filenames

# --- Configuration ---
# Every path below is anchored to BASE_DIR so the app resolves the same files
# regardless of the working directory it is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # Directory containing this script

# >>> Point this to your actual config file <<<
# Resolves to <BASE_DIR>/config/config.yaml.
CONFIG_PATH = os.path.join(BASE_DIR, "config/config.yaml")

# >>> Point this to your actual checkpoint file <<<
# Anchored to BASE_DIR like the other paths; a bare relative name would only
# resolve when the process happens to be started from this directory.
CHECKPOINT_PATH = os.path.join(BASE_DIR, "ckpt_000-vloss_0.4685_vf1_0.6469.ckpt")

UPLOAD_FOLDER = os.path.join(BASE_DIR, 'uploads')        # Where uploaded originals are written
RESULT_FOLDER = os.path.join(BASE_DIR, 'results')        # Where overlay images are written
LOG_FILE_PATH = os.path.join(BASE_DIR, 'flask_app.log')  # Rotating application log
# Lower-case file extensions accepted for uploads.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp', 'tif', 'tiff'}

# --- Logging Setup ---
# Handlers are attached to the root logger, so records from this module and
# from imported libraries share the same console stream and log file.
logger = logging.getLogger()
# Drop handlers left over from a previous import/reload; otherwise each record
# would be emitted once per stale handler.
logger.handlers.clear()
logger.setLevel(logging.INFO) # Minimum severity emitted (e.g., INFO, DEBUG)

# One format shared by both outputs
log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s')

console_handler = logging.StreamHandler()
# Rotating file output: up to 5 MiB per file, 3 rotated backups retained
file_handler = RotatingFileHandler(LOG_FILE_PATH, maxBytes=5*1024*1024, backupCount=3)

# Console first, then file - same registration order as before
for _handler in (console_handler, file_handler):
    _handler.setFormatter(log_formatter)
    logger.addHandler(_handler)

# --- Ensure upload and result directories exist ---
# Both folders are written to by the /api/analyze endpoint, so the app cannot
# run without them.
try:
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    os.makedirs(RESULT_FOLDER, exist_ok=True)
    logger.info(f"Ensured directories exist: {UPLOAD_FOLDER}, {RESULT_FOLDER}")
except OSError as e:
    logger.error(f"Error creating directories: {e}")
    # Raise SystemExit directly: the `exit` helper is injected by the `site`
    # module for interactive sessions and is not guaranteed to exist.
    raise SystemExit(1)

# --- Load Config ---
# `config` is the module-level OmegaConf object shared by model construction
# and the inference pipeline. Startup aborts if it cannot be loaded.
config = None
try:
    # OmegaConf.load parses the YAML file directly; no OmegaConf.create() needed
    config = OmegaConf.load(CONFIG_PATH)
    logger.info(f"Configuration loaded successfully from: {CONFIG_PATH}")
    # Reading two required keys here surfaces a malformed config at startup
    # instead of on the first request.
    logger.info(f"Config check: num_classes={config.data.num_classes}, model_name={config.training.model_name}")
except FileNotFoundError:
    logger.error(f"Configuration file not found: {CONFIG_PATH}")
    # `exit` is an interactive-shell helper from `site`; raise SystemExit instead.
    raise SystemExit(1)
except Exception as e: # Parse errors, missing keys, unreadable file, ...
    logger.error(f"Error loading or parsing configuration file '{CONFIG_PATH}': {e}")
    logger.error(traceback.format_exc())
    raise SystemExit(1)

# --- Model Definition ---
class InferenceModel(torch.nn.Module):
    """SegFormer wrapper whose forward pass returns logits at input resolution.

    The Hugging Face network is stored as ``self.model``, so the parameter
    names of this module are those of the wrapped network under ``model.``.
    """

    def __init__(self, model_config):
        """Build the wrapped SegFormer from ``model_config``.

        Args:
            model_config: Config object exposing ``training.model_name``
                (handed to ``from_pretrained``) and ``data.num_classes``
                (used as ``num_labels``).

        Raises:
            AttributeError: Re-raised after logging when a config attribute
                lookup fails.
            Exception: Any other construction error, re-raised after logging.
        """
        super().__init__()
        try:
            # Pull the two settings the constructor depends on
            model_name = model_config.training.model_name
            num_classes = model_config.data.num_classes
            logger.info(f"Initializing SegformerForSemanticSegmentation with model='{model_name}', num_labels={num_classes}")
            # ignore_mismatched_sizes lets the classification head be rebuilt
            # when num_labels differs from the pretrained checkpoint.
            segformer = SegformerForSemanticSegmentation.from_pretrained(
                model_name,
                num_labels=num_classes,
                ignore_mismatched_sizes=True,
            )
            self.model = segformer
            logger.info("Segformer model part initialized.")
        except AttributeError as ae:
            logger.error(f"Config error during model init: Missing key? {ae}")
            logger.error(f"Check if 'training.model_name' and 'data.num_classes' exist in {CONFIG_PATH}")
            raise # Propagate so the caller can abort
        except Exception as e:
            logger.error(f"Error initializing Segformer model from Hugging Face: {e}")
            logger.error(traceback.format_exc())
            raise # Propagate so the caller can abort

    def forward(self, x):
        """Return class logits resized to the spatial size of ``x``.

        Args:
            x: Input tensor passed to the wrapped model as ``pixel_values``;
                its last two dimensions are used as the target (H, W).

        Returns:
            The wrapped model's logits, bilinearly interpolated to ``x.shape[-2:]``.
        """
        raw_logits = self.model(pixel_values=x, return_dict=True).logits
        target_hw = x.shape[-2:]
        return F.interpolate(raw_logits, size=target_hw, mode="bilinear", align_corners=False)

# --- Utility Functions ---
def num_to_rgb(num_arr, color_map_dict):
    """Convert an integer label mask into a float RGB colour image.

    Args:
        num_arr: Label mask as a numpy array. Arrays with more than two
            dimensions have their singleton axes squeezed away, so ``(H, W)``,
            ``(1, H, W)`` and ``(H, W, 1)`` are all accepted.
        color_map_dict: Plain ``dict`` mapping ``int`` label -> ``[R, G, B]``
            (list or tuple of three 0-255 values).

    Returns:
        ``float32`` array of shape ``(H, W, 3)`` with values in [0, 1]. Labels
        without a usable colour stay black. If ``color_map_dict`` is not a
        dict, an all-black image is returned and an error is logged.
    """
    label_mask = np.asarray(num_arr)
    # Squeeze only when there are extra axes: squeezing a genuine 2-D mask
    # with a single row or column would collapse it to 1-D.
    if label_mask.ndim > 2:
        label_mask = np.squeeze(label_mask)
    # Size the canvas from the same array that is used for indexing below, so
    # the boolean masks always line up with the output.
    output = np.zeros(label_mask.shape[:2] + (3,), dtype=np.uint8)

    if not isinstance(color_map_dict, dict):
        logger.error(f"Invalid color_map provided to num_to_rgb: {type(color_map_dict)}. Expected dict.")
        return np.float32(output) / 255.0 # Black float image

    for label in np.unique(label_mask):
        label_key = int(label) # numpy integer -> plain int for the dict lookup
        if label_key not in color_map_dict:
            # Label 0 is commonly background and may legitimately be unmapped
            if label_key != 0:
                logger.warning(f"Label Key {label_key} found in mask but not in provided color map.")
            continue # Pixels keep the black default from np.zeros

        color = color_map_dict[label_key]
        if isinstance(color, (list, tuple)) and len(color) == 3:
            output[label_mask == label] = color
        else:
            logger.warning(f"Invalid color format for label {label_key} in color map: {color}. Skipping.")

    return np.float32(output) / 255.0 # float32 RGB image in [0, 1]

def denormalize(tensor, mean, std):
    """Undo per-channel normalisation and return one image in HWC layout.

    Args:
        tensor: torch tensor shaped ``(1, C, H, W)``. A ``(C, H, W)`` tensor is
            also accepted and is given a batch axis (a warning is logged).
        mean: list/tuple with one mean per channel.
        std: list/tuple with one standard deviation per channel.

    Returns:
        numpy array of shape ``(H, W, C)`` with values clamped to [0, 1], or
        ``None`` (after logging an error) when the arguments are invalid:
        wrong container type, wrong rank, batch size other than 1, or a
        mean/std length that differs from the channel count.
    """
    if not isinstance(mean, (list, tuple)) or not isinstance(std, (list, tuple)):
        logger.error(f"Mean ({type(mean)}) or std ({type(std)}) are not lists/tuples in denormalize.")
        return None

    if tensor.dim() == 3:
        # Recoverable: treat CHW as a batch of one. Only a warning is logged.
        logger.warning("Denormalize received 3D tensor, adding batch dimension.")
        tensor = tensor.unsqueeze(0)
    elif tensor.dim() != 4:
        logger.error(f"Unexpected tensor dimension {tensor.dim()} in denormalize. Expected 4 (BCHW).")
        return None

    batch_size, num_channels = tensor.shape[0], tensor.shape[1]
    if batch_size != 1:
        # The HWC conversion below is only defined for a single image.
        logger.error(f"Denormalize expects a single image but received batch size {batch_size}.")
        return None
    if len(mean) != num_channels or len(std) != num_channels:
        logger.error(f"Mean/std length ({len(mean)}/{len(std)}) mismatch with tensor channels ({num_channels})")
        return None

    # Work on a CPU copy so the caller's tensor is left untouched
    tensor = tensor.clone().cpu()
    for c, (channel_mean, channel_std) in enumerate(zip(mean, std)):
        tensor[:, c, :, :] = tensor[:, c, :, :] * channel_std + channel_mean

    # Drop the batch axis, clamp to the displayable range, CHW -> HWC
    return torch.clamp(tensor[0], 0, 1).permute(1, 2, 0).numpy()

# --- Load Model (Corrected Version) ---
def _align_model_prefix(state_dict, target_state_dict):
    """Add or strip the 'model.' key prefix so checkpoint keys match the target.

    The direction is decided from the first key of each mapping, taken in
    insertion order so the decision is the same on every run. A key whose
    renamed form the target does not expect is kept under its original name,
    so ``load_state_dict`` can report it afterwards.

    Args:
        state_dict: Mapping of parameter name -> value read from the checkpoint.
        target_state_dict: ``state_dict()`` of the model receiving the weights.

    Returns:
        A mapping to pass to ``load_state_dict``; ``state_dict`` itself when no
        renaming is attempted.
    """
    prefix = 'model.'
    target_keys = set(target_state_dict.keys())
    first_loaded_key = next(iter(state_dict), None)
    first_target_key = next(iter(target_state_dict), None)
    comparable = bool(first_loaded_key) and bool(first_target_key)
    loaded_prefixed = comparable and first_loaded_key.startswith(prefix)
    target_prefixed = comparable and first_target_key.startswith(prefix)

    if comparable and target_prefixed and not loaded_prefixed:
        logger.warning("Checkpoint keys missing 'model.' prefix. Attempting to add it.")
        corrected = {}
        kept_unprefixed = []
        for key, value in state_dict.items():
            prefixed_key = f"{prefix}{key}"
            if prefixed_key in target_keys:
                corrected[prefixed_key] = value
            else:
                # Target does not expect the prefixed name; keep the original
                kept_unprefixed.append(key)
                corrected[key] = value
        if kept_unprefixed:
            logger.warning(f"Keys kept without prefix (target doesn't expect): {kept_unprefixed}")
        logger.info("Finished attempting prefix addition.")
        return corrected

    if comparable and loaded_prefixed and not target_prefixed:
        logger.warning("Checkpoint keys HAVE 'model.' prefix, but target model DOES NOT. Attempting to remove it.")
        corrected = {}
        kept_prefixed = []
        for key, value in state_dict.items():
            if not key.startswith(prefix):
                corrected[key] = value # Never had the prefix
                continue
            stripped_key = key[len(prefix):]
            if stripped_key in target_keys:
                corrected[stripped_key] = value
            else:
                # Target does not expect the stripped name; keep the original
                kept_prefixed.append(key)
                corrected[key] = value
        if kept_prefixed:
            logger.warning(f"Keys kept with prefix (target doesn't expect stripped): {kept_prefixed}")
        logger.info("Finished attempting prefix removal.")
        return corrected

    logger.info("State dict keys seem to have correct prefix structure (or other mismatch). Using as is.")
    return state_dict


def load_trained_model(checkpoint_path, model_config):
    """Build an ``InferenceModel`` and load trained weights from a checkpoint.

    Args:
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        model_config: Config object passed to ``InferenceModel``.

    Returns:
        The model in eval mode, or ``None`` when loading fails. Failure means
        the file is missing, the state dict is empty, required keys are
        missing, or there are unexpected keys other than
        ``*num_batches_tracked``. Every failure is logged; nothing is raised.
    """
    try:
        model_instance = InferenceModel(model_config) # Create model structure
        logger.info(f"Attempting to load checkpoint from: {checkpoint_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at specified path: {checkpoint_path}")

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        # `device` is the module-level device string, assigned before this
        # function is first called.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
        logger.info(f"Checkpoint loaded into memory. Type: {type(checkpoint)}")

        # Extract the state dictionary - flexible based on common saving patterns
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            logger.info("Using 'state_dict' key from checkpoint.")
        elif isinstance(checkpoint, dict):
            # Assume the dict *is* the state_dict if 'state_dict' key is absent
            state_dict = checkpoint
            logger.info("Using checkpoint dictionary directly as state_dict (no 'state_dict' key found).")
        else:
            logger.warning(f"Checkpoint is not a dictionary. Attempting to load directly into model (less common). Type was: {type(checkpoint)}")
            try:
                model_instance.load_state_dict(checkpoint) # Try loading directly
            except Exception as e:
                logger.error(f"Failed to load state_dict directly from checkpoint object: {e}")
                return None
            logger.info("Loaded state_dict directly from checkpoint object.")
            model_instance.eval()
            return model_instance

        if not state_dict:
            logger.warning("Loaded state_dict is empty!")
            return None

        # --- Key Prefix Correction Logic ---
        corrected_state_dict = _align_model_prefix(state_dict, model_instance.state_dict())

        # --- Load the State Dictionary ---
        # strict=False so mismatches are returned for inspection rather than raised
        logger.info("Attempting to load state_dict with strict=False for checking...")
        missing_keys, unexpected_keys = model_instance.load_state_dict(corrected_state_dict, strict=False)

        final_msg = []
        is_load_successful = True
        if missing_keys:
            final_msg.append(f"MISSING keys in checkpoint: {missing_keys}")
            logger.error("CRITICAL FAILURE: Model is missing required keys.")
            is_load_successful = False
        if unexpected_keys:
            final_msg.append(f"UNEXPECTED keys in checkpoint (exist in file but not in model): {unexpected_keys}")
            # Only keys ending in 'num_batches_tracked' are tolerated
            acceptable_unexpected = [k for k in unexpected_keys if k.endswith('num_batches_tracked')]
            unacceptable_unexpected = [k for k in unexpected_keys if not k.endswith('num_batches_tracked')]
            if unacceptable_unexpected:
                logger.error(f"CRITICAL FAILURE: Model received unacceptable unexpected keys: {unacceptable_unexpected}")
                is_load_successful = False
            elif acceptable_unexpected:
                logger.warning(f"Ignoring acceptable unexpected keys: {acceptable_unexpected}")

        if not is_load_successful:
            logger.error(f"State dict loading failed. Issues: {'; '.join(final_msg)}")
            return None

        logger.info(f"State dictionary loaded successfully. Issues (if any): {final_msg if final_msg else 'None'}")
        model_instance.eval() # Set to evaluation mode
        logger.info(f"Model loading process complete for {checkpoint_path}")
        return model_instance

    except FileNotFoundError as fnf_error:
        logger.error(f"{fnf_error}") # Log the specific FileNotFoundError message
        return None
    except Exception as e:
        logger.error(f"Unexpected error during model loading: {e}")
        logger.error(traceback.format_exc()) # Log full traceback
        return None


# --- Determine device & Load Model Globally ---
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
# The model is loaded once at import time and shared by all requests.
model = load_trained_model(CHECKPOINT_PATH, config) # Pass the loaded config
if model is None:
    logger.critical("CRITICAL: Failed to load model. Application cannot continue.")
    # Raise SystemExit directly: the `exit` helper comes from the `site`
    # module and is intended for interactive sessions.
    raise SystemExit(1)
model.to(device) # Move model to the appropriate device

# --- Inference Pipeline (Corrected Config Handling) ---
def run_inference_on_bytes(image_bytes, inference_model, model_config, device):
    """Decode an encoded image, segment it and build a colour overlay.

    Args:
        image_bytes: Encoded image file contents, decoded with ``cv2.imdecode``.
        inference_model: Callable taking the ``[1, 3, H, W]`` input tensor and
            returning ``[1, num_classes, H, W]`` logits.
        model_config: OmegaConf config providing ``data.image_size``,
            ``data.mean``, ``data.std``, ``data.num_classes``, ``id2color``
            and ``training.model_name``.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        ``(denorm_img, color_mask, overlay)`` as HWC RGB arrays in [0, 1], or
        ``(None, None, None)`` on any failure (the reason is logged).
    """
    failure = (None, None, None)
    try:
        encoded_buffer = np.frombuffer(image_bytes, np.uint8)
        img = cv2.imdecode(encoded_buffer, cv2.IMREAD_COLOR)
        if img is None:
            logger.error("Failed cv2.imdecode.")
            return failure
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # OpenCV decodes to BGR
        logger.debug("Image decoded and converted to RGB.")

        # --- Preprocessing (with config conversion) ---
        # Every required field must be present and truthy; checked in a fixed
        # order so the first missing one is the one reported.
        required_fields = [
            (f'data.{key}', f"Config missing required data field: data.{key}")
            for key in ('image_size', 'mean', 'std', 'num_classes')
        ]
        required_fields.append(('id2color', "Config missing required field: id2color"))
        required_fields.append(('training.model_name', "Config missing required field: training.model_name"))
        for field_path, missing_message in required_fields:
            if not OmegaConf.select(model_config, field_path, default=None):
                logger.error(missing_message)
                return failure

        try:
            # OmegaConf containers -> plain Python; resolve=True expands any
            # ${...} interpolations used in these fields.
            data_cfg = model_config.data
            img_size = tuple(OmegaConf.to_container(data_cfg.image_size, resolve=True))
            mean = list(OmegaConf.to_container(data_cfg.mean, resolve=True))
            std = list(OmegaConf.to_container(data_cfg.std, resolve=True))
            raw_color_map = OmegaConf.to_container(model_config.id2color, resolve=True)
            # Force integer keys so lookups by predicted label work
            id2color_map = {int(k): v for k, v in raw_color_map.items()}
            num_classes = int(data_cfg.num_classes)

            logger.debug(f"Converted config values: size={img_size}, mean={mean}, std={std}, id2color keys={list(id2color_map.keys())}, num_classes={num_classes}")

            # Basic validation after conversion
            converted_ok = isinstance(mean, list) and isinstance(std, list) and isinstance(id2color_map, dict)
            if not converted_ok:
                raise TypeError("Config values did not convert to list/dict.")
            if len(mean) != 3 or len(std) != 3:
                raise ValueError("Incorrect mean/std length. Expected 3.") # Three colour channels
            if len(img_size) != 2:
                raise ValueError("Incorrect image_size length. Expected 2 (H, W).")

        except Exception as e:
            logger.error(f"Error processing/converting configuration values: {e}")
            logger.error(traceback.format_exc())
            return failure

        # ToTensor: HWC uint8 -> CHW float in [0, 1]; then resize and normalise
        preprocessing_steps = [
            transforms.ToTensor(),
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize(mean=mean, std=std),
        ]
        transform = transforms.Compose(preprocessing_steps)
        logger.debug(f"Image transform applied for size {img_size}.")
        input_tensor = transform(img_rgb).unsqueeze(0).to(device) # Batch of one, on target device
        logger.debug(f"Input tensor created with shape: {input_tensor.shape}")

        # --- Run Prediction ---
        with torch.no_grad():
            logits = inference_model(input_tensor)
            logger.debug(f"Logits received with shape: {logits.shape}")
            shape_ok = logits.dim() == 4 and logits.shape[1] == num_classes
            if not shape_ok:
                logger.error(f"Unexpected final logits shape or class number: {logits.shape}. Expected B x {num_classes} x H x W.")
                return failure
            # Most likely class per pixel, as an (H, W) uint8 numpy array
            class_indices = logits.argmax(1).squeeze(0)
            pred_mask = class_indices.cpu().numpy().astype(np.uint8)
            logger.debug(f"Prediction mask generated with shape: {pred_mask.shape}")

        # --- Post-processing ---
        color_mask = num_to_rgb(pred_mask, id2color_map)
        if color_mask is None:
            logger.error("num_to_rgb failed.")
            return failure
        logger.debug("Color mask generated.")

        # The overlay is drawn on the network input (resized image), not on
        # the original-resolution upload.
        denorm_img = denormalize(input_tensor, mean, std)
        if denorm_img is None:
            logger.error("denormalize failed.")
            return failure
        logger.debug("Input tensor denormalized for overlay.")

        # --- Create Overlay ---
        base_hw = denorm_img.shape[:2]
        if base_hw != color_mask.shape[:2]:
            logger.warning(f"Denorm img shape {denorm_img.shape[:2]} != Color mask shape {color_mask.shape[:2]}. Resizing color mask using INTER_NEAREST.")
            # cv2.resize takes (width, height); nearest-neighbour keeps class colours unblended
            target_wh = (denorm_img.shape[1], denorm_img.shape[0])
            color_mask = cv2.resize(color_mask, target_wh, interpolation=cv2.INTER_NEAREST)

        # 70% image + 30% colour mask
        overlay = cv2.addWeighted(denorm_img, 0.7, color_mask, 0.3, 0)
        logger.debug("Overlay created using cv2.addWeighted.")

        return denorm_img, color_mask, overlay

    except Exception as e:
        logger.error(f"Exception during inference pipeline for image: {e}")
        logger.error(traceback.format_exc())
        return failure


# --- Flask App ---
# WSGI application object; the route decorators below register on it.
app = Flask(__name__)
CORS(app) # No options passed, so flask-cors defaults apply to the whole app (presumably all origins on all routes - confirm for the installed version). A narrower setup would pass resources={r"/api/*": {"origins": "*"}, r"/Result/*": {"origins": "*"}}
logger.info("Flask app created and CORS enabled.")

def allowed_file(filename):
    """Return True when the text after the last '.' is an allowed extension.

    The comparison is case-insensitive. A name without any '.' is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS

# --- API Endpoints ---
@app.route('/api/analyze', methods=['POST'])
def analyze_image():
    """Run segmentation on a Base64-encoded image and store the results.

    Expects a JSON object with string fields ``image`` (Base64, optionally
    prefixed by a ``data:...;base64,`` header) and ``filename``. The decoded
    upload is written to UPLOAD_FOLDER and the overlay to RESULT_FOLDER as
    ``analyzed_<filename>``.

    Returns:
        ``(json, status)``: 200 with ``overlay_filename`` and relative
        ``paths`` on success, 400 for malformed input, and 500 when saving or
        inference fails.
    """
    endpoint_log_prefix = "[POST /api/analyze]"
    logger.info(f"{endpoint_log_prefix} Received request.")

    # --- Basic Checks ---
    if model is None:
        logger.error(f"{endpoint_log_prefix} Model not loaded.")
        return jsonify({"success": False, "message": "Model not loaded"}), 500
    if not request.is_json:
        logger.warning(f"{endpoint_log_prefix} Not JSON.")
        return jsonify({"success": False, "message": "Request must be JSON"}), 400

    data = request.get_json()
    if not isinstance(data, dict) or 'image' not in data or 'filename' not in data:
        # Log only the shape of the payload: the body may hold a multi-megabyte
        # Base64 image that must not end up in the log file.
        received = sorted(data) if isinstance(data, dict) else type(data).__name__
        logger.warning(f"{endpoint_log_prefix} Missing image/filename in JSON body. Received: {received}")
        return jsonify({"success": False, "message": "Missing 'image' (base64) or 'filename' in JSON body"}), 400

    base64_image_data = data['image']
    original_filename = data['filename']
    if not isinstance(base64_image_data, str) or not isinstance(original_filename, str):
        # Non-string values would otherwise raise below and surface as a 500
        logger.warning(f"{endpoint_log_prefix} 'image' or 'filename' is not a string.")
        return jsonify({"success": False, "message": "'image' and 'filename' must be strings"}), 400

    logger.info(f"{endpoint_log_prefix} Original filename from request: '{original_filename}'")
    safe_original_filename = secure_filename(original_filename) # Sanitize
    if not safe_original_filename or not allowed_file(safe_original_filename):
        logger.warning(f"{endpoint_log_prefix} Invalid/disallowed filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
        return jsonify({"success": False, "message": "Invalid or disallowed filename/extension"}), 400
    logger.info(f"{endpoint_log_prefix} Sanitized filename for saving/processing: '{safe_original_filename}'")

    try:
        # --- Decode Base64 ---
        # Drop an optional "data:<mime>;base64," header (text up to the first comma)
        _header, separator, payload = base64_image_data.partition(',')
        encoded = payload if separator else base64_image_data
        try:
            image_bytes = base64.b64decode(encoded)
        except ValueError as e:
            # binascii.Error (bad padding/characters) is a ValueError subclass;
            # non-ASCII input raises a plain ValueError.
            logger.error(f"{endpoint_log_prefix} Invalid Base64 data received: {e}")
            return jsonify({"success": False, "message": "Invalid Base64 image data received"}), 400
        if not image_bytes:
            logger.warning(f"{endpoint_log_prefix} Base64 payload decoded to zero bytes.")
            return jsonify({"success": False, "message": "Invalid Base64 image data received"}), 400
        logger.info(f"{endpoint_log_prefix} Base64 image decoded ({len(image_bytes)} bytes).")

        # --- Save Original Image ---
        original_path = os.path.join(UPLOAD_FOLDER, safe_original_filename)
        try:
            with open(original_path, "wb") as f:
                f.write(image_bytes)
            logger.info(f"{endpoint_log_prefix} Original image saved to: '{original_path}'")
        except Exception as e:
            logger.error(f"{endpoint_log_prefix} Failed to save original image to '{original_path}': {e}")
            return jsonify({"success": False, "message": "Failed to save uploaded image on server"}), 500

        # --- Run Inference ---
        logger.info(f"{endpoint_log_prefix} Starting inference for '{safe_original_filename}'...")
        denorm_img, color_mask, overlay = run_inference_on_bytes(image_bytes, model, config, device)
        if overlay is None: # The pipeline reports failure as (None, None, None)
            logger.error(f"{endpoint_log_prefix} Inference pipeline returned None for '{safe_original_filename}'.")
            return jsonify({"success": False, "message": "Inference process failed on server"}), 500
        logger.info(f"{endpoint_log_prefix} Inference completed successfully for '{safe_original_filename}'.")

        # --- Save Overlay Image ---
        # The naming scheme must match the one used by /api/toggle-image
        name_part, ext = os.path.splitext(safe_original_filename)
        overlay_filename = f"analyzed_{name_part}{ext}"
        overlay_path = os.path.join(RESULT_FOLDER, overlay_filename)
        logger.info(f"{endpoint_log_prefix} Determined overlay filename: '{overlay_filename}' -> path: '{overlay_path}'")

        # float RGB in [0, 1] -> uint8 BGR in [0, 255], as cv2.imwrite expects
        try:
            overlay_to_save_uint8 = (overlay * 255).astype(np.uint8)
            overlay_to_save_bgr = cv2.cvtColor(overlay_to_save_uint8, cv2.COLOR_RGB2BGR)
            save_success = cv2.imwrite(overlay_path, overlay_to_save_bgr)
            if not save_success:
                raise IOError(f"cv2.imwrite failed to save the overlay image to {overlay_path}")
            logger.info(f"{endpoint_log_prefix} Overlay image saved successfully to: '{overlay_path}'")
        except Exception as e:
            logger.error(f"{endpoint_log_prefix} Failed to convert or save overlay image to '{overlay_path}': {e}")
            logger.error(traceback.format_exc())
            return jsonify({"success": False, "message": "Failed to save analysis result image"}), 500

        # --- Success Response ---
        logger.info(f"{endpoint_log_prefix} Analysis successful for '{safe_original_filename}'. Returning success.")
        return jsonify({
            "success": True,
            "message": "Analysis complete",
            # Paths are informational; the client mainly needs overlay_filename
            "paths": {"original": os.path.relpath(original_path, BASE_DIR), "overlay": os.path.relpath(overlay_path, BASE_DIR)},
            "overlay_filename": overlay_filename # The exact filename saved
        }), 200

    except Exception as e:
        logger.error(f"{endpoint_log_prefix} Unexpected error during analysis request processing: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"success": False, "message": "Internal server error during analysis processing"}), 500


@app.route('/api/toggle-image', methods=['GET'])
def get_analysis_path():
    """Report whether an overlay exists for the file named in ``?filename=``.

    Returns:
        ``(json, status)``: 200 with ``filepath`` set to the overlay filename
        when it exists in RESULT_FOLDER, 400 when the parameter is missing or
        sanitises to nothing, and 404 when no overlay is found.
    """
    endpoint_log_prefix = "[GET /api/toggle-image]"
    logger.info(f"{endpoint_log_prefix} Received request.")
    logger.info(f"{endpoint_log_prefix} Full request URL: {request.url}")
    logger.info(f"{endpoint_log_prefix} Request Query Args: {request.args}")

    original_filename = request.args.get('filename')
    if not original_filename:
        logger.warning(f"{endpoint_log_prefix} Missing 'filename' query parameter.")
        return jsonify({"message": "Missing 'filename' query parameter"}), 400
    logger.info(f"{endpoint_log_prefix} Original filename received from query: '{original_filename}'")

    # Same sanitisation as /api/analyze, so both endpoints derive the same name
    safe_original_filename = secure_filename(original_filename)
    if not safe_original_filename:
        logger.warning(f"{endpoint_log_prefix} Invalid filename after sanitization: '{safe_original_filename}' from '{original_filename}'")
        return jsonify({"message": "Invalid filename format"}), 400
    logger.info(f"{endpoint_log_prefix} Sanitized filename for lookup: '{safe_original_filename}'")

    # --- Construct Expected Overlay Path (MUST match /analyze logic) ---
    stem, extension = os.path.splitext(safe_original_filename)
    expected_overlay_filename = f"analyzed_{stem}{extension}"
    expected_overlay_path = os.path.join(RESULT_FOLDER, expected_overlay_filename)
    logger.info(f"{endpoint_log_prefix} Expecting overlay file at: '{expected_overlay_path}'")

    # --- Check if File Exists ---
    if not os.path.exists(expected_overlay_path):
        logger.warning(f"{endpoint_log_prefix} Analysis result file NOT FOUND at checked path: '{expected_overlay_path}'")
        return jsonify({"message": f"Analysis result not found for '{original_filename}'"}), 404

    logger.info(f"{endpoint_log_prefix} Found analysis result file: '{expected_overlay_filename}'")
    # Only the filename is returned; the client builds the /Result/ URL itself
    return jsonify({"filepath": expected_overlay_filename}), 200


@app.route('/Result/<filename>')
def serve_result_image(filename):
    """Serve an overlay image from RESULT_FOLDER.

    Args:
        filename: Name taken from the URL path; it is sanitised with
            ``secure_filename`` before use.

    Returns:
        The file response on success, a JSON 404 when the file does not
        exist, or a JSON 500 for any other error while serving.
    """
    # Imported here so this handler does not depend on a separate import edit;
    # werkzeug is already a dependency of this module.
    from werkzeug.exceptions import NotFound

    endpoint_log_prefix = "[GET /Result]"
    safe_filename = secure_filename(filename)
    if safe_filename != filename:
        logger.warning(f"{endpoint_log_prefix} Requested filename '{filename}' was sanitized to '{safe_filename}'. Serving sanitized version.")
    if not safe_filename:
        # Nothing usable is left after sanitisation, so no file can match
        logger.error(f"{endpoint_log_prefix} Requested file not found in result folder: '{safe_filename}'")
        return jsonify({"message": "Requested analysis image not found"}), 404

    logger.info(f"{endpoint_log_prefix} Attempting to serve file: '{safe_filename}' from directory: '{RESULT_FOLDER}'")
    try:
        # as_attachment=False lets the browser display the image inline
        return send_from_directory(RESULT_FOLDER, safe_filename, as_attachment=False)
    except (NotFound, FileNotFoundError):
        # send_from_directory signals a missing file with werkzeug's NotFound;
        # without catching it here the generic handler below would answer 500.
        logger.error(f"{endpoint_log_prefix} Requested file not found in result folder: '{safe_filename}'")
        return jsonify({"message": "Requested analysis image not found"}), 404
    except Exception as e:
        # Anything else, e.g. permission problems
        logger.error(f"{endpoint_log_prefix} Error serving file '{safe_filename}': {e}")
        logger.error(traceback.format_exc())
        return jsonify({"message": "Error serving analysis image"}), 500


# --- Main Execution ---
if __name__ == '__main__':
    # The model is loaded at import time; refuse to serve without it.
    if model is not None:
        logger.info("Model loaded successfully. Starting Flask development server...")
        # Debug mode enables the Werkzeug interactive debugger, which allows
        # arbitrary code execution. Because the server binds to 0.0.0.0 it is
        # off by default and must be enabled explicitly with FLASK_DEBUG=1.
        debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
        # host='0.0.0.0' makes the server reachable from the network
        app.run(host='0.0.0.0', port=7860, debug=debug_enabled)
    else:
        logger.critical("APPLICATION FAILED TO START: MODEL COULD NOT BE LOADED.")
        # Non-zero exit status signals the failure; `exit` from `site` is an
        # interactive-shell helper, so raise SystemExit directly.
        raise SystemExit(1)