Spaces:
Sleeping
Sleeping
| from typing import Tuple, Dict | |
| import gradio as gr | |
| import supervision as sv | |
| import numpy as np | |
| import cv2 | |
| from huggingface_hub import hf_hub_download | |
| from ultralytics import YOLO | |
# Display name shown in the UI -> weights filename inside the Hub repository.
MODEL_OPTIONS = {
    "YOLOv11-Small": "medieval-yolo11s-seg.pt"
}

# Registry of successfully loaded models, keyed by display name.
models: Dict[str, YOLO] = {}

# Download and instantiate every configured model. A failure is printed and
# that model is simply left out of the registry.
for name, model_file in MODEL_OPTIONS.items():
    try:
        weights_path = hf_hub_download(
            repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
            filename=model_file,
        )
        loaded_model = YOLO(weights_path)
    except Exception as e:
        print(f"Error loading model {name}: {str(e)}")
    else:
        models[name] = loaded_model

# Shared annotators used to draw masks and text labels.
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
MASK_ANNOTATOR = sv.MaskAnnotator()
def process_masks(masks: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray:
    """Resize each mask to ``target_shape`` and binarize it.

    Args:
        masks: Stack of masks iterated along axis 0, one 2-D mask per entry,
            or ``None``.
        target_shape: ``(height, width)`` every mask is resized to.

    Returns:
        ``None`` when ``masks`` is ``None``; otherwise a boolean array of
        shape ``(len(masks), height, width)``. An empty input yields an
        empty ``(0, height, width)`` array so the result is always 3-D.
    """
    if masks is None:
        return None

    h, w = target_shape

    # Without this guard an empty stack would collapse to shape (0,),
    # which is inconsistent with the (N, h, w) result for N > 0.
    if len(masks) == 0:
        return np.zeros((0, h, w), dtype=bool)

    return np.array([
        # cv2.resize takes the size as (width, height); thresholding the
        # interpolated values at 0.5 turns them back into a binary mask.
        cv2.resize(mask.astype(float), (w, h), interpolation=cv2.INTER_LINEAR) > 0.5
        for mask in masks
    ])
def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float
) -> np.ndarray:
    """Run segmentation on ``image`` and return a copy with masks and labels drawn.

    Args:
        image: Input image array; ``None`` is passed straight through.
        model_name: Key into the module-level ``models`` registry.
        conf_threshold: Forwarded to ``model.predict`` as ``conf``.
        iou_threshold: Forwarded to ``model.predict`` as ``iou``.

    Returns:
        The annotated image; ``None`` when ``image`` is ``None``; or the
        unmodified input image when any step raises (the error is printed
        rather than propagated).
    """
    try:
        if image is None:
            return None

        model = models.get(model_name)
        if model is None:
            raise ValueError(f"Model {model_name} not loaded")

        # Perform inference; only the first (single-image) result is used.
        # NOTE(review): the UI supplies the image via gr.Image(type='numpy');
        # confirm the channel order matches what model.predict expects for
        # numpy inputs.
        results = model.predict(
            image,
            conf=conf_threshold,
            iou=iou_threshold
        )[0]

        # Convert results to plain numpy arrays for supervision.
        boxes = results.boxes.xyxy.cpu().numpy()
        confidence = results.boxes.conf.cpu().numpy()
        class_ids = results.boxes.cls.cpu().numpy().astype(int)

        # Process masks. The stack is kept as (N, H, W): process_masks
        # iterates over axis 0 and must receive one 2-D mask per detection.
        # Moving the instance axis last would make it iterate over image
        # rows and return an array that no longer lines up with the boxes.
        masks = None
        if results.masks is not None:
            masks = process_masks(
                results.masks.data.cpu().numpy(),
                image.shape[:2]
            )

        # Create Detections object
        detections = sv.Detections(
            xyxy=boxes,
            confidence=confidence,
            class_id=class_ids,
            mask=masks
        )

        # One "<class name> (<confidence>)" label per detection.
        labels = [
            f"{results.names[class_id]} ({conf:.2f})"
            for class_id, conf in zip(class_ids, confidence)
        ]

        # Annotate a copy so the caller's image is left untouched.
        annotated_image = image.copy()
        if masks is not None:
            annotated_image = MASK_ANNOTATOR.annotate(
                scene=annotated_image,
                detections=detections
            )
        annotated_image = LABEL_ANNOTATOR.annotate(
            scene=annotated_image,
            detections=detections,
            labels=labels
        )
        return annotated_image
    except Exception as e:
        # Deliberate best-effort: report the problem and show the original.
        print(f"Error during detection: {str(e)}")
        return image  # Return original image on error
# Build the Gradio UI: input image and settings on the left, result on the right.
# NOTE(review): component nesting was reconstructed from a flattened listing —
# confirm the layout against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type='numpy')
            with gr.Accordion("Detection Settings", open=True):
                _model_names = list(MODEL_OPTIONS)
                model_selector = gr.Dropdown(
                    choices=_model_names,
                    value=_model_names[0],
                    label="Model",
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                iou_threshold = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.45,
                )
            detect_btn = gr.Button("Detect", variant="primary")
            clear_btn = gr.Button("Clear")

        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type='numpy')

    def process_image(image, model_name, conf_threshold, iou_threshold):
        """Return ``(input image, annotated image)`` for the two image panels."""
        if image is None:
            return None, None
        try:
            result = detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
        except Exception as e:
            print(f"Error in process_image: {str(e)}")
            return image, image  # Fallback to original image
        return image, result

    def clear():
        """Reset both image panels."""
        return None, None

    # Both handlers write to the same pair of image components.
    image_panels = [input_image, output_image]

    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold],
        outputs=image_panels,
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=image_panels,
    )
if __name__ == "__main__":
    # Launch settings collected in one place and passed through to Gradio.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "debug": True,
    }
    demo.launch(**launch_kwargs)