import os
import json
import gc
import time
import traceback
from typing import Dict, List, Optional, Tuple, Callable, Any

import torch
import gradio as gr
import supervision as sv
from PIL import Image

try:
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        AutoModelForImageTextToText,
        AutoProcessor,
        BitsAndBytesConfig,
    )
except Exception:
    AutoModelForCausalLM = None
    AutoTokenizer = None
    AutoModelForImageTextToText = None
    AutoProcessor = None
    BitsAndBytesConfig = None

try:
    from rfdetr import RFDETRMedium
except ImportError:
    print("Warning: RF-DETR not found. Please ensure it's properly installed.")
    RFDETRMedium = None
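# Note: the imports above are intentionally defensive. transformers and rfdetr are
# wrapped in try/except and fall back to None, so the Space can still start and
# report a readable error in the UI when an optional dependency is missing.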
class SpacesConfig:
    """Configuration optimized for Hugging Face Spaces."""

    def __init__(self):
        self.settings = {
            'results_dir': '/tmp/results',
            'checkpoint': None,
            'resolution': 576,
            'threshold': 0.7,
            'use_llm': True,
            'llm_model_id': 'google/medgemma-4b-it',
            'llm_max_new_tokens': 200,
            'llm_temperature': 0.2,
            'llm_4bit': True,
            'enable_caching': True,
            'max_cache_size': 100,
        }

    def get(self, key: str, default: Any = None) -> Any:
        return self.settings.get(key, default)
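# Usage sketch for SpacesConfig.get(): values fall back to the supplied default, e.g.
#   SpacesConfig().get('resolution')        # -> 576
#   SpacesConfig().get('missing_key', 1.0)  # -> 1.0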
class MemoryManager:
    """Simplified memory management for Spaces."""

    def __init__(self):
        self.memory_thresholds = {
            'gpu_warning': 0.8,
            'system_warning': 0.85,
        }

    def cleanup_memory(self, force: bool = False) -> None:
        """Perform memory cleanup."""
        try:
            gc.collect()
            if torch and torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception as e:
            print(f"Memory cleanup error: {e}")


memory_manager = MemoryManager()
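# cleanup_memory() is called before loading the LLM (see TextGenerator.load_model)
# to release cached CUDA blocks. The 'force' flag and 'memory_thresholds' are not
# used elsewhere in this file; they are presumably hooks for stricter policies.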
def find_checkpoint() -> Optional[str]:
    """Find RF-DETR checkpoint in various locations."""
    candidates = [
        "rf-detr-medium.pth",
        "/tmp/results/checkpoint_best_total.pth",
        "/tmp/results/checkpoint_best_ema.pth",
        "/tmp/results/checkpoint_best_regular.pth",
        "/tmp/results/checkpoint.pth",
    ]

    for path in candidates:
        if os.path.isfile(path):
            return path
    return None
def load_model(checkpoint_path: str, resolution: int):
    """Load the RF-DETR detection model."""
    if RFDETRMedium is None:
        raise RuntimeError("RF-DETR not available. Please install it properly.")

    model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
    try:
        # Best effort: inference optimization is optional and failures are non-fatal.
        model.optimize_for_inference()
    except Exception:
        pass
    return model
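# Minimal usage sketch (assumes a checkpoint exists at one of the paths probed by
# find_checkpoint(); `pil_image` is a placeholder):
#   ckpt = find_checkpoint()
#   model = load_model(ckpt, resolution=576)
#   detections = model.predict(pil_image, threshold=0.7)  # sv.Detections, as used below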
class TextGenerator:
    """Simplified text generator for Spaces."""

    def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.is_multimodal = False
    def load_model(self):
        """Load the LLM (lazily, on first use)."""
        if self.model is not None:
            return

        if AutoModelForCausalLM is None and AutoModelForImageTextToText is None:
            raise RuntimeError("Transformers not available")

        memory_manager.cleanup_memory()

        print(f"Loading model: {self.model_id}")

        model_kwargs = {
            "device_map": "auto",
            "low_cpu_mem_usage": True,
        }

        if torch and torch.cuda.is_available():
            model_kwargs["torch_dtype"] = torch.bfloat16

        # 4-bit NF4 quantization (with double quantization) to reduce GPU memory use.
        if BitsAndBytesConfig is not None:
            try:
                compute_dtype = torch.bfloat16 if torch and torch.cuda.is_available() else torch.float16
                model_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
                model_kwargs["torch_dtype"] = compute_dtype
            except Exception:
                pass

        # MedGemma is multimodal, so it is loaded through the image-text-to-text API.
        is_multimodal = "medgemma" in self.model_id.lower()

        if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
            self.is_multimodal = True
        elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
            self.is_multimodal = False
        else:
            raise RuntimeError("Required model classes not available")

        print("Model loaded successfully")
    def generate(self, text: str, image: Optional[Image.Image] = None) -> str:
        """Generate text using the loaded model."""
        self.load_model()

        if self.model is None:
            return f"[Model not loaded: {text}]"

        try:
            system_text = (
                "You are a concise medical assistant. Provide a brief, clear summary "
                "of detection results. Avoid repetition and be direct. Do not give medical advice."
            )
            user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"

            if self.is_multimodal:
                # Multimodal path: pass the annotated image alongside the text prompt.
                user_content = [{"type": "text", "text": user_text}]
                if image is not None:
                    user_content.append({"type": "image", "image": image})

                messages = [
                    {"role": "system", "content": [{"type": "text", "text": system_text}]},
                    {"role": "user", "content": user_content},
                ]

                inputs = self.processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                )

                if torch:
                    inputs = inputs.to(self.model.device, dtype=torch.bfloat16)

                with torch.inference_mode():
                    generation = self.model.generate(
                        **inputs,
                        max_new_tokens=self.max_tokens,
                        do_sample=self.temperature > 0,
                        temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
                        use_cache=False,
                    )

                # Strip the prompt tokens so only the newly generated text is decoded.
                input_len = inputs["input_ids"].shape[-1]
                generation = generation[0][input_len:]
                decoded = self.processor.decode(generation, skip_special_tokens=True)
                return decoded.strip()

            else:
                # Text-only path for non-multimodal models.
                messages = [
                    {"role": "system", "content": system_text},
                    {"role": "user", "content": user_text},
                ]

                inputs = self.tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                )

                inputs = inputs.to(self.model.device)

                with torch.inference_mode():
                    generation = self.model.generate(
                        **inputs,
                        max_new_tokens=self.max_tokens,
                        do_sample=self.temperature > 0,
                        temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
                        use_cache=False,
                    )

                input_len = inputs["input_ids"].shape[-1]
                generation = generation[0][input_len:]
                decoded = self.tokenizer.decode(generation, skip_special_tokens=True)
                return decoded.strip()

        except Exception as e:
            error_msg = f"[Generation error: {e}]"
            print(f"Generation error: {traceback.format_exc()}")
            return f"{error_msg}\n\n{text}"
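# Note: both branches of TextGenerator.generate() build the prompt with the
# chat-template API and then slice off the prompt tokens (input_len) before
# decoding, so only the newly generated summary is returned to the caller.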
class AppState:
    """Application state for Spaces."""

    def __init__(self):
        self.config = SpacesConfig()
        self.model = None
        self.class_names = None
        self.text_generator = None

    def load_model(self):
        """Load the detection model."""
        if self.model is not None:
            return

        checkpoint = find_checkpoint()
        if not checkpoint:
            raise FileNotFoundError(
                "No RF-DETR checkpoint found. Please upload rf-detr-medium.pth to your Space."
            )

        print(f"Loading RF-DETR from: {checkpoint}")
        self.model = load_model(checkpoint, self.config.get('resolution'))

        # Recover human-readable class names from the training results, if available.
        try:
            results_json = "/tmp/results/results.json"
            if os.path.isfile(results_json):
                with open(results_json, 'r') as f:
                    data = json.load(f)
                classes = []
                for split in ("valid", "test", "train"):
                    if "class_map" in data and split in data["class_map"]:
                        for item in data["class_map"][split]:
                            name = item.get("class")
                            if name and name != "all" and name not in classes:
                                classes.append(name)
                self.class_names = classes if classes else None
        except Exception:
            pass

        print("RF-DETR model loaded")

    def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
        """Get or create the text generator for the requested model size."""
        model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'

        # Recreate the generator only when the requested model changes.
        if self.text_generator is None or self.text_generator.model_id != model_id:
            max_tokens = self.config.get('llm_max_new_tokens')
            temperature = self.config.get('llm_temperature')
            self.text_generator = TextGenerator(model_id, max_tokens, temperature)
        return self.text_generator
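# Note: models are loaded lazily. The first call to annotate_image() below triggers
# AppState.load_model(), and TextGenerator.load_model() runs only when an LLM summary
# is requested, so the Space starts quickly and pays the load cost on first use.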
def create_detection_interface():
    """Create the Gradio interface."""

    COLOR_PALETTE = sv.ColorPalette.from_hex([
        "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
        "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
        "#66ff66", "#99ff00",
    ])
    def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]:
        """Process an image and return an annotated version with a description."""
        if image is None:
            return None, "Please upload an image."

        try:
            app_state.load_model()

            detections = app_state.model.predict(image, threshold=threshold)

            bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
            label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)

            # Resolve class ids to readable names for the box labels.
            labels = []
            for i in range(len(detections)):
                class_id = int(detections.class_id[i]) if detections.class_id is not None else None
                conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0

                if app_state.class_names and class_id is not None:
                    if 0 <= class_id < len(app_state.class_names):
                        label_name = app_state.class_names[class_id]
                    else:
                        label_name = str(class_id)
                else:
                    label_name = str(class_id) if class_id is not None else "object"

                labels.append(f"{label_name} {conf:.2f}")

            annotated = image.copy()
            annotated = bbox_annotator.annotate(annotated, detections)
            annotated = label_annotator.annotate(annotated, detections, labels)

            # Build a plain-text summary of the detections.
            description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"

            if len(detections) > 0:
                counts = {}
                for i in range(len(detections)):
                    class_id = int(detections.class_id[i]) if detections.class_id is not None else None
                    if app_state.class_names and class_id is not None:
                        if 0 <= class_id < len(app_state.class_names):
                            name = app_state.class_names[class_id]
                        else:
                            name = str(class_id)
                    else:
                        name = str(class_id) if class_id is not None else "object"
                    counts[name] = counts.get(name, 0) + 1

                for name, count in counts.items():
                    description += f"- {count}× {name}\n"

                # Optionally rewrite the summary with MedGemma.
                if app_state.config.get('use_llm'):
                    try:
                        generator = app_state.get_text_generator(model_size)
                        llm_description = generator.generate(description, image=annotated)
                        description = llm_description
                    except Exception as e:
                        description = f"[LLM error: {e}]\n\n{description}"
            else:
                description += "No objects detected above the confidence threshold."

            return annotated, description

        except Exception as e:
            error_msg = f"Error processing image: {str(e)}"
            print(f"Processing error: {traceback.format_exc()}")
            return None, error_msg
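    # The summary from annotate_image() is shown as-is, or rewritten by MedGemma when
    # 'use_llm' is enabled. The UI below wires the same handler to both the Analyze
    # button and the image-upload event.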
    with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Medical Image Analysis")
        gr.Markdown("Upload a medical image to detect and analyze findings using AI.")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Image", height=400)
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )

                model_size_radio = gr.Radio(
                    choices=["4B", "27B"],
                    value="4B",
                    label="MedGemma Model Size",
                    info="4B: Faster, less memory | 27B: More accurate, more memory"
                )

                analyze_btn = gr.Button("Analyze Image", variant="primary")

            with gr.Column():
                output_image = gr.Image(type="pil", label="Results", height=400)
                output_text = gr.Textbox(
                    label="Analysis Results",
                    lines=8,
                    max_lines=15,
                    show_copy_button=True
                )

        analyze_btn.click(
            fn=annotate_image,
            inputs=[input_image, threshold_slider, model_size_radio],
            outputs=[output_image, output_text]
        )

        # Also run the analysis automatically whenever a new image is uploaded.
        input_image.change(
            fn=annotate_image,
            inputs=[input_image, threshold_slider, model_size_radio],
            outputs=[output_image, output_text]
        )

        gr.Markdown("---")
        gr.Markdown("*Powered by RF-DETR and MedGemma • Built for Hugging Face Spaces*")

    return demo
app_state = AppState()


def main():
    """Main entry point for the Spaces app."""
    print("Starting Medical Image Analysis App")

    os.makedirs(app_state.config.get('results_dir'), exist_ok=True)

    demo = create_detection_interface()

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        show_api=False,
    )
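# Local usage (sketch; assumes this file is saved as app.py, the Spaces convention):
#   python app.py  ->  serves the Gradio UI on 0.0.0.0:7860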
if __name__ == "__main__":
    main()