Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ from da2.utils.base import load_config
|
|
| 10 |
from da2.utils.model import load_model
|
| 11 |
from da2.utils.io import read_cv2_image, torch_transform, tensorize
|
| 12 |
|
| 13 |
-
# 引入真正的 Accelerator
|
| 14 |
from accelerate import Accelerator
|
| 15 |
from accelerate.logging import get_logger
|
| 16 |
|
|
@@ -18,14 +18,13 @@ from accelerate.logging import get_logger
|
|
| 18 |
# Global Initialization (Run once at startup)
|
| 19 |
# ============================================================
|
| 20 |
def initialize_app(config_path="configs/infer.json"):
|
| 21 |
-
# 1.
|
| 22 |
-
# 它會自動偵測環境 (CPU 或 GPU) 並初始化全局狀態,解決 logging 錯誤
|
| 23 |
accelerator = Accelerator()
|
| 24 |
|
| 25 |
# 2. 載入 Config
|
| 26 |
config = load_config(config_path)
|
| 27 |
|
| 28 |
-
# 設定 Logger
|
| 29 |
logger = get_logger(__name__, log_level="INFO")
|
| 30 |
config.setdefault("env", {})
|
| 31 |
config["env"]["logger"] = logger
|
|
@@ -33,8 +32,7 @@ def initialize_app(config_path="configs/infer.json"):
|
|
| 33 |
|
| 34 |
accelerator.print(f"Running on device: {accelerator.device}")
|
| 35 |
|
| 36 |
-
# 3. 載入模型 (Global Load)
|
| 37 |
-
# 注意:我們在這裡就載入模型,而不是在推論時才載入
|
| 38 |
model = load_model(config, accelerator)
|
| 39 |
model = model.to(accelerator.device)
|
| 40 |
model.eval()
|
|
@@ -47,14 +45,12 @@ try:
|
|
| 47 |
print("Model loaded successfully!")
|
| 48 |
except Exception as e:
|
| 49 |
print(f"Error loading model: {e}")
|
| 50 |
-
# 為了防止 Build 失敗,這裡可以設為 None,但在推論時會報錯
|
| 51 |
CONFIG, ACCELERATOR, MODEL = None, None, None
|
| 52 |
|
| 53 |
# ============================================================
|
| 54 |
# Mask loader
|
| 55 |
# ============================================================
|
| 56 |
def read_mask_demo(mask_path, img_shape):
|
| 57 |
-
# 處理無 Mask 的情況,回傳全 True 的 Mask
|
| 58 |
if mask_path is None:
|
| 59 |
if len(img_shape) == 3:
|
| 60 |
return np.ones((img_shape[1], img_shape[2]), dtype=bool)
|
|
@@ -62,11 +58,15 @@ def read_mask_demo(mask_path, img_shape):
|
|
| 62 |
|
| 63 |
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
| 64 |
if mask is None:
|
| 65 |
-
# 如果讀取失敗,也回傳全 True
|
| 66 |
if len(img_shape) == 3:
|
| 67 |
return np.ones((img_shape[1], img_shape[2]), dtype=bool)
|
| 68 |
return np.ones(img_shape[:2], dtype=bool)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
return mask > 0
|
| 71 |
|
| 72 |
# ============================================================
|
|
@@ -78,40 +78,49 @@ def run_inference_and_save_depth(image_path, mask_path=None):
|
|
| 78 |
|
| 79 |
device = ACCELERATOR.device
|
| 80 |
|
| 81 |
-
#
|
| 82 |
cv2_img = read_cv2_image(image_path)
|
| 83 |
if cv2_img is None:
|
| 84 |
print(f"Error reading image: {image_path}")
|
| 85 |
return None, None
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
img_tensor = torch_transform(cv2_img) # CxHxW tensor
|
| 88 |
|
| 89 |
-
#
|
| 90 |
mask = read_mask_demo(mask_path, img_tensor.shape)
|
| 91 |
|
| 92 |
-
#
|
| 93 |
model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
|
| 94 |
input_tensor = tensorize(img_tensor, model_dtype, device)
|
| 95 |
|
| 96 |
-
# Autocast
|
| 97 |
use_autocast = (device.type == "cuda")
|
| 98 |
autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()
|
| 99 |
|
|
|
|
| 100 |
with autocast_ctx, torch.no_grad():
|
| 101 |
-
# Model inference
|
| 102 |
pred = MODEL(input_tensor)
|
| 103 |
|
| 104 |
-
# Handle tuple outputs
|
| 105 |
if isinstance(pred, (tuple, list)):
|
| 106 |
pred = pred[0]
|
| 107 |
|
| 108 |
-
#
|
| 109 |
depth = pred.float().cpu().numpy()
|
| 110 |
|
| 111 |
-
#
|
| 112 |
depth = np.squeeze(depth)
|
| 113 |
-
|
| 114 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
dmin, dmax = float(np.nanmin(depth)), float(np.nanmax(depth))
|
| 116 |
|
| 117 |
if dmax - dmin > 1e-6:
|
|
@@ -121,7 +130,7 @@ def run_inference_and_save_depth(image_path, mask_path=None):
|
|
| 121 |
|
| 122 |
depth_8bit = (depth_norm * 255).astype(np.uint8)
|
| 123 |
|
| 124 |
-
#
|
| 125 |
os.makedirs("outputs", exist_ok=True)
|
| 126 |
base = os.path.splitext(os.path.basename(image_path))[0]
|
| 127 |
out_path = f"outputs/{base}_depth.png"
|
|
@@ -138,7 +147,6 @@ def gradio_fn(image, mask):
|
|
| 138 |
depth_img, out_path = run_inference_and_save_depth(image, mask)
|
| 139 |
return depth_img, out_path
|
| 140 |
|
| 141 |
-
# 定義 Gradio 介面
|
| 142 |
demo = gr.Interface(
|
| 143 |
fn=gradio_fn,
|
| 144 |
inputs=[
|
|
@@ -150,7 +158,7 @@ demo = gr.Interface(
|
|
| 150 |
gr.File(label="Download Depth PNG"),
|
| 151 |
],
|
| 152 |
title="DA² — Minimal Depth Demo",
|
| 153 |
-
description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG.",
|
| 154 |
allow_flagging="never",
|
| 155 |
)
|
| 156 |
|
|
|
|
| 10 |
from da2.utils.model import load_model
|
| 11 |
from da2.utils.io import read_cv2_image, torch_transform, tensorize
|
| 12 |
|
| 13 |
+
# 引入真正的 Accelerator
|
| 14 |
from accelerate import Accelerator
|
| 15 |
from accelerate.logging import get_logger
|
| 16 |
|
|
|
|
| 18 |
# Global Initialization (Run once at startup)
|
| 19 |
# ============================================================
|
| 20 |
def initialize_app(config_path="configs/infer.json"):
|
| 21 |
+
# 1. 初始化 Accelerator (解決 logging 報錯)
|
|
|
|
| 22 |
accelerator = Accelerator()
|
| 23 |
|
| 24 |
# 2. 載入 Config
|
| 25 |
config = load_config(config_path)
|
| 26 |
|
| 27 |
+
# 設定 Logger
|
| 28 |
logger = get_logger(__name__, log_level="INFO")
|
| 29 |
config.setdefault("env", {})
|
| 30 |
config["env"]["logger"] = logger
|
|
|
|
| 32 |
|
| 33 |
accelerator.print(f"Running on device: {accelerator.device}")
|
| 34 |
|
| 35 |
+
# 3. 載入模型 (Global Load,避免每次推論重載)
|
|
|
|
| 36 |
model = load_model(config, accelerator)
|
| 37 |
model = model.to(accelerator.device)
|
| 38 |
model.eval()
|
|
|
|
| 45 |
print("Model loaded successfully!")
|
| 46 |
except Exception as e:
|
| 47 |
print(f"Error loading model: {e}")
|
|
|
|
| 48 |
CONFIG, ACCELERATOR, MODEL = None, None, None
|
| 49 |
|
| 50 |
# ============================================================
|
| 51 |
# Mask loader
|
| 52 |
# ============================================================
|
| 53 |
def read_mask_demo(mask_path, img_shape):
|
|
|
|
| 54 |
if mask_path is None:
|
| 55 |
if len(img_shape) == 3:
|
| 56 |
return np.ones((img_shape[1], img_shape[2]), dtype=bool)
|
|
|
|
| 58 |
|
| 59 |
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
| 60 |
if mask is None:
|
|
|
|
| 61 |
if len(img_shape) == 3:
|
| 62 |
return np.ones((img_shape[1], img_shape[2]), dtype=bool)
|
| 63 |
return np.ones(img_shape[:2], dtype=bool)
|
| 64 |
|
| 65 |
+
# 確保 Mask 尺寸也跟隨 Tensor (如果需要的話),但通常 Mask 是跟原圖
|
| 66 |
+
# 這裡簡單處理,如果尺寸不對稍微 resize 一下避免崩潰
|
| 67 |
+
if mask.shape[:2] != img_shape[-2:]:
|
| 68 |
+
mask = cv2.resize(mask, (img_shape[-1], img_shape[-2]), interpolation=cv2.INTER_NEAREST)
|
| 69 |
+
|
| 70 |
return mask > 0
|
| 71 |
|
| 72 |
# ============================================================
|
|
|
|
| 78 |
|
| 79 |
device = ACCELERATOR.device
|
| 80 |
|
| 81 |
+
# 1. 讀取原始圖片
|
| 82 |
cv2_img = read_cv2_image(image_path)
|
| 83 |
if cv2_img is None:
|
| 84 |
print(f"Error reading image: {image_path}")
|
| 85 |
return None, None
|
| 86 |
|
| 87 |
+
# 【關鍵修復】獲取原始圖片的寬高 (Height, Width)
|
| 88 |
+
# cv2 shape 是 (H, W, C)
|
| 89 |
+
original_h, original_w = cv2_img.shape[:2]
|
| 90 |
+
|
| 91 |
+
# 2. 轉換為 Tensor (這步通常會 resize 成 518x518 或其他模型輸入尺寸)
|
| 92 |
img_tensor = torch_transform(cv2_img) # CxHxW tensor
|
| 93 |
|
| 94 |
+
# 處理 Mask
|
| 95 |
mask = read_mask_demo(mask_path, img_tensor.shape)
|
| 96 |
|
| 97 |
+
# 準備輸入
|
| 98 |
model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
|
| 99 |
input_tensor = tensorize(img_tensor, model_dtype, device)
|
| 100 |
|
| 101 |
+
# Autocast
|
| 102 |
use_autocast = (device.type == "cuda")
|
| 103 |
autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()
|
| 104 |
|
| 105 |
+
# 3. 推論
|
| 106 |
with autocast_ctx, torch.no_grad():
|
|
|
|
| 107 |
pred = MODEL(input_tensor)
|
| 108 |
|
|
|
|
| 109 |
if isinstance(pred, (tuple, list)):
|
| 110 |
pred = pred[0]
|
| 111 |
|
| 112 |
+
# 轉回 Numpy float32
|
| 113 |
depth = pred.float().cpu().numpy()
|
| 114 |
|
| 115 |
+
# 4. 後處理
|
| 116 |
depth = np.squeeze(depth)
|
| 117 |
+
|
| 118 |
+
# 【關鍵修復】將深度圖 Resize 回原始尺寸
|
| 119 |
+
# cv2.resize 接受的參數是 (Width, Height)
|
| 120 |
+
if (depth.shape[0] != original_h) or (depth.shape[1] != original_w):
|
| 121 |
+
depth = cv2.resize(depth, (original_w, original_h), interpolation=cv2.INTER_CUBIC)
|
| 122 |
+
|
| 123 |
+
# 5. 正規化 (Normalization) -> 8-bit
|
| 124 |
dmin, dmax = float(np.nanmin(depth)), float(np.nanmax(depth))
|
| 125 |
|
| 126 |
if dmax - dmin > 1e-6:
|
|
|
|
| 130 |
|
| 131 |
depth_8bit = (depth_norm * 255).astype(np.uint8)
|
| 132 |
|
| 133 |
+
# 6. 儲存
|
| 134 |
os.makedirs("outputs", exist_ok=True)
|
| 135 |
base = os.path.splitext(os.path.basename(image_path))[0]
|
| 136 |
out_path = f"outputs/{base}_depth.png"
|
|
|
|
| 147 |
depth_img, out_path = run_inference_and_save_depth(image, mask)
|
| 148 |
return depth_img, out_path
|
| 149 |
|
|
|
|
| 150 |
demo = gr.Interface(
|
| 151 |
fn=gradio_fn,
|
| 152 |
inputs=[
|
|
|
|
| 158 |
gr.File(label="Download Depth PNG"),
|
| 159 |
],
|
| 160 |
title="DA² — Minimal Depth Demo",
|
| 161 |
+
description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG (Resized to Original).",
|
| 162 |
allow_flagging="never",
|
| 163 |
)
|
| 164 |
|