import os
import cv2
import torch
import numpy as np
import gradio as gr
from contextlib import nullcontext

# DA² utilities
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import read_cv2_image, torch_transform, tensorize

# Use the real Accelerator
from accelerate import Accelerator
from accelerate.logging import get_logger


# ============================================================
# Global Initialization (Run once at startup)
# ============================================================
def initialize_app(config_path="configs/infer.json"):
    # 1. Initialize the Accelerator (fixes the logging error)
    accelerator = Accelerator()

    # 2. Load the config
    config = load_config(config_path)

    # Set up the logger
    logger = get_logger(__name__, log_level="INFO")
    config.setdefault("env", {})
    config["env"]["logger"] = logger
    config["env"].setdefault("seed", 42)

    accelerator.print(f"Running on device: {accelerator.device}")

    # 3. Load the model once globally, so it is not reloaded on every inference
    model = load_model(config, accelerator)
    model = model.to(accelerator.device)
    model.eval()

    return config, accelerator, model


# Initialize global state
try:
    CONFIG, ACCELERATOR, MODEL = initialize_app()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    CONFIG, ACCELERATOR, MODEL = None, None, None


# ============================================================
# Mask loader
# ============================================================
def read_mask_demo(mask_path, img_shape):
    # No mask provided: use an all-ones mask matching the tensor's (H, W)
    if mask_path is None:
        if len(img_shape) == 3:  # CxHxW
            return np.ones((img_shape[1], img_shape[2]), dtype=bool)
        return np.ones(img_shape[:2], dtype=bool)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:  # unreadable file: fall back to an all-ones mask
        if len(img_shape) == 3:
            return np.ones((img_shape[1], img_shape[2]), dtype=bool)
        return np.ones(img_shape[:2], dtype=bool)

    # A mask usually matches the original image, not the model tensor.
    # Keep it simple here: if the sizes disagree, resize (nearest-neighbor)
    # to the tensor's spatial size to avoid crashing.
    if mask.shape[:2] != tuple(img_shape[-2:]):
        mask = cv2.resize(mask, (img_shape[-1], img_shape[-2]), interpolation=cv2.INTER_NEAREST)
    return mask > 0


# ============================================================
# Core inference function
# ============================================================
def run_inference_and_save_depth(image_path, mask_path=None):
    if MODEL is None:
        print("Error: model not loaded.")
        return None, None

    device = ACCELERATOR.device

    # 1. Read the original image; cv2 images are (H, W, C)
    cv2_img = read_cv2_image(image_path)
    if cv2_img is None:
        print(f"Error reading image: {image_path}")
        return None, None

    # [Key fix] Remember the original (height, width) so the depth map
    # can be resized back after inference.
    original_h, original_w = cv2_img.shape[:2]

    # 2. Convert to a CxHxW tensor (this step typically resizes the image
    #    to 518x518 or whatever input size the model expects)
    img_tensor = torch_transform(cv2_img)

    # Load the mask (NOTE: in this minimal demo the mask is loaded but
    # not passed to the model)
    mask = read_mask_demo(mask_path, img_tensor.shape)

    # Prepare the model input
    model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
    input_tensor = tensorize(img_tensor, model_dtype, device)

    # Autocast only on CUDA; use a no-op context elsewhere
    use_autocast = (device.type == "cuda")
    autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()

    # 3. Inference
    with autocast_ctx, torch.no_grad():
        pred = MODEL(input_tensor)
        if isinstance(pred, (tuple, list)):
            pred = pred[0]

    # Back to a float32 NumPy array
    depth = pred.float().cpu().numpy()

    # 4. Post-processing
    depth = np.squeeze(depth)

    # [Key fix] Resize the depth map back to the original resolution.
    # Note that cv2.resize expects (width, height).
    if (depth.shape[0] != original_h) or (depth.shape[1] != original_w):
        depth = cv2.resize(depth, (original_w, original_h), interpolation=cv2.INTER_CUBIC)

    # 5. Normalize to [0, 1], then convert to 8-bit
    dmin, dmax = float(np.nanmin(depth)), float(np.nanmax(depth))
    if dmax - dmin > 1e-6:
        depth_norm = (depth - dmin) / (dmax - dmin)
    else:
        depth_norm = np.zeros_like(depth, dtype=np.float32)
    depth_8bit = (depth_norm * 255).astype(np.uint8)

    # 6. Save
    os.makedirs("outputs", exist_ok=True)
    base = os.path.splitext(os.path.basename(image_path))[0]
    out_path = f"outputs/{base}_depth.png"
    cv2.imwrite(out_path, depth_8bit)

    return depth_8bit, out_path


# ============================================================
# Gradio UI
# ============================================================
def gradio_fn(image, mask):
    if image is None:
        return None, None
    depth_img, out_path = run_inference_and_save_depth(image, mask)
    return depth_img, out_path


demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.Image(label="Optional Mask", type="filepath"),
    ],
    outputs=[
        gr.Image(label="Depth (8-bit Grayscale)", type="numpy"),
        gr.File(label="Download Depth PNG"),
    ],
    title="DA² — Minimal Depth Demo",
    description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG, resized back to the original resolution.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
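
# ------------------------------------------------------------
# Usage sketch (assumes this script is saved as app.py at the DA²
# repo root with configs/infer.json present; the filename is
# illustrative, not part of the DA² repo):
#
#   python app.py
#
# Gradio prints a local URL (http://127.0.0.1:7860 by default); open
# it, upload an image (and optionally a mask), then download the
# generated outputs/<name>_depth.png.
# ------------------------------------------------------------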