fcu52005505 committed on
Commit
9701a8b
·
verified ·
1 Parent(s): 05c1031

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -22
app.py CHANGED
@@ -10,7 +10,7 @@ from da2.utils.base import load_config
10
  from da2.utils.model import load_model
11
  from da2.utils.io import read_cv2_image, torch_transform, tensorize
12
 
13
- # 引入真正的 Accelerator 來解決 logging 報錯
14
  from accelerate import Accelerator
15
  from accelerate.logging import get_logger
16
 
@@ -18,14 +18,13 @@ from accelerate.logging import get_logger
18
  # Global Initialization (Run once at startup)
19
  # ============================================================
20
  def initialize_app(config_path="configs/infer.json"):
21
- # 1. 初始化真正的 Accelerator
22
- # 它會自動偵測環境 (CPU 或 GPU) 並初始化全局狀態,解決 logging 錯誤
23
  accelerator = Accelerator()
24
 
25
  # 2. 載入 Config
26
  config = load_config(config_path)
27
 
28
- # 設定 Logger (這時 accelerate 已經初始化,get_logger 會正常運作)
29
  logger = get_logger(__name__, log_level="INFO")
30
  config.setdefault("env", {})
31
  config["env"]["logger"] = logger
@@ -33,8 +32,7 @@ def initialize_app(config_path="configs/infer.json"):
33
 
34
  accelerator.print(f"Running on device: {accelerator.device}")
35
 
36
- # 3. 載入模型 (Global Load)
37
- # 注意:我們在這裡就載入模型,而不是在推論時才載入
38
  model = load_model(config, accelerator)
39
  model = model.to(accelerator.device)
40
  model.eval()
@@ -47,14 +45,12 @@ try:
47
  print("Model loaded successfully!")
48
  except Exception as e:
49
  print(f"Error loading model: {e}")
50
- # 為了防止 Build 失敗,這裡可以設為 None,但在推論時會報錯
51
  CONFIG, ACCELERATOR, MODEL = None, None, None
52
 
53
  # ============================================================
54
  # Mask loader
55
  # ============================================================
56
  def read_mask_demo(mask_path, img_shape):
57
- # 處理無 Mask 的情況,回傳全 True 的 Mask
58
  if mask_path is None:
59
  if len(img_shape) == 3:
60
  return np.ones((img_shape[1], img_shape[2]), dtype=bool)
@@ -62,11 +58,15 @@ def read_mask_demo(mask_path, img_shape):
62
 
63
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
64
  if mask is None:
65
- # 如果讀取失敗,也回傳全 True
66
  if len(img_shape) == 3:
67
  return np.ones((img_shape[1], img_shape[2]), dtype=bool)
68
  return np.ones(img_shape[:2], dtype=bool)
69
 
 
 
 
 
 
70
  return mask > 0
71
 
72
  # ============================================================
@@ -78,40 +78,49 @@ def run_inference_and_save_depth(image_path, mask_path=None):
78
 
79
  device = ACCELERATOR.device
80
 
81
- # Read image
82
  cv2_img = read_cv2_image(image_path)
83
  if cv2_img is None:
84
  print(f"Error reading image: {image_path}")
85
  return None, None
86
 
 
 
 
 
 
87
  img_tensor = torch_transform(cv2_img) # CxHxW tensor
88
 
89
- # 確保 mask 形狀正確
90
  mask = read_mask_demo(mask_path, img_tensor.shape)
91
 
92
- # Move to device with expected dtype
93
  model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
94
  input_tensor = tensorize(img_tensor, model_dtype, device)
95
 
96
- # Autocast (GPU only)
97
  use_autocast = (device.type == "cuda")
98
  autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()
99
 
 
100
  with autocast_ctx, torch.no_grad():
101
- # Model inference
102
  pred = MODEL(input_tensor)
103
 
104
- # Handle tuple outputs
105
  if isinstance(pred, (tuple, list)):
106
  pred = pred[0]
107
 
108
- # Always convert to float32 before numpy to avoid BF16/FP16 errors
109
  depth = pred.float().cpu().numpy()
110
 
111
- # Normalize depth -> 8-bit grayscale
112
  depth = np.squeeze(depth)
113
-
114
- # 處理 NaN 和極值
 
 
 
 
 
115
  dmin, dmax = float(np.nanmin(depth)), float(np.nanmax(depth))
116
 
117
  if dmax - dmin > 1e-6:
@@ -121,7 +130,7 @@ def run_inference_and_save_depth(image_path, mask_path=None):
121
 
122
  depth_8bit = (depth_norm * 255).astype(np.uint8)
123
 
124
- # Save output
125
  os.makedirs("outputs", exist_ok=True)
126
  base = os.path.splitext(os.path.basename(image_path))[0]
127
  out_path = f"outputs/{base}_depth.png"
@@ -138,7 +147,6 @@ def gradio_fn(image, mask):
138
  depth_img, out_path = run_inference_and_save_depth(image, mask)
139
  return depth_img, out_path
140
 
141
- # 定義 Gradio 介面
142
  demo = gr.Interface(
143
  fn=gradio_fn,
144
  inputs=[
@@ -150,7 +158,7 @@ demo = gr.Interface(
150
  gr.File(label="Download Depth PNG"),
151
  ],
152
  title="DA² — Minimal Depth Demo",
153
- description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG.",
154
  allow_flagging="never",
155
  )
156
 
 
10
  from da2.utils.model import load_model
11
  from da2.utils.io import read_cv2_image, torch_transform, tensorize
12
 
13
+ # 引入真正的 Accelerator
14
  from accelerate import Accelerator
15
  from accelerate.logging import get_logger
16
 
 
18
  # Global Initialization (Run once at startup)
19
  # ============================================================
20
  def initialize_app(config_path="configs/infer.json"):
21
+ # 1. 初始化 Accelerator (解決 logging 報錯)
 
22
  accelerator = Accelerator()
23
 
24
  # 2. 載入 Config
25
  config = load_config(config_path)
26
 
27
+ # 設定 Logger
28
  logger = get_logger(__name__, log_level="INFO")
29
  config.setdefault("env", {})
30
  config["env"]["logger"] = logger
 
32
 
33
  accelerator.print(f"Running on device: {accelerator.device}")
34
 
35
+ # 3. 載入模型 (Global Load,避免每次推論重載)
 
36
  model = load_model(config, accelerator)
37
  model = model.to(accelerator.device)
38
  model.eval()
 
45
  print("Model loaded successfully!")
46
  except Exception as e:
47
  print(f"Error loading model: {e}")
 
48
  CONFIG, ACCELERATOR, MODEL = None, None, None
49
 
50
  # ============================================================
51
  # Mask loader
52
  # ============================================================
53
  def read_mask_demo(mask_path, img_shape):
 
54
  if mask_path is None:
55
  if len(img_shape) == 3:
56
  return np.ones((img_shape[1], img_shape[2]), dtype=bool)
 
58
 
59
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
60
  if mask is None:
 
61
  if len(img_shape) == 3:
62
  return np.ones((img_shape[1], img_shape[2]), dtype=bool)
63
  return np.ones(img_shape[:2], dtype=bool)
64
 
65
+ # 確保 Mask 尺寸也跟隨 Tensor (如果需要的話),但通常 Mask 是跟原圖
66
+ # 這裡簡單處理,如果尺寸不對稍微 resize 一下避免崩潰
67
+ if mask.shape[:2] != img_shape[-2:]:
68
+ mask = cv2.resize(mask, (img_shape[-1], img_shape[-2]), interpolation=cv2.INTER_NEAREST)
69
+
70
  return mask > 0
71
 
72
  # ============================================================
 
78
 
79
  device = ACCELERATOR.device
80
 
81
+ # 1. 讀取原始圖片
82
  cv2_img = read_cv2_image(image_path)
83
  if cv2_img is None:
84
  print(f"Error reading image: {image_path}")
85
  return None, None
86
 
87
+ # 【關鍵修復】獲取原始圖片的寬高 (Height, Width)
88
+ # cv2 shape 是 (H, W, C)
89
+ original_h, original_w = cv2_img.shape[:2]
90
+
91
+ # 2. 轉換為 Tensor (這步通常會 resize 成 518x518 或其他模型輸入尺寸)
92
  img_tensor = torch_transform(cv2_img) # CxHxW tensor
93
 
94
+ # 處理 Mask
95
  mask = read_mask_demo(mask_path, img_tensor.shape)
96
 
97
+ # 準備輸入
98
  model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
99
  input_tensor = tensorize(img_tensor, model_dtype, device)
100
 
101
+ # Autocast
102
  use_autocast = (device.type == "cuda")
103
  autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()
104
 
105
+ # 3. 推論
106
  with autocast_ctx, torch.no_grad():
 
107
  pred = MODEL(input_tensor)
108
 
 
109
  if isinstance(pred, (tuple, list)):
110
  pred = pred[0]
111
 
112
+ # 轉回 Numpy float32
113
  depth = pred.float().cpu().numpy()
114
 
115
+ # 4. 後處理
116
  depth = np.squeeze(depth)
117
+
118
+ # 【關鍵修復】將深度圖 Resize 回原始尺寸
119
+ # cv2.resize 接受的參數是 (Width, Height)
120
+ if (depth.shape[0] != original_h) or (depth.shape[1] != original_w):
121
+ depth = cv2.resize(depth, (original_w, original_h), interpolation=cv2.INTER_CUBIC)
122
+
123
+ # 5. 正規化 (Normalization) -> 8-bit
124
  dmin, dmax = float(np.nanmin(depth)), float(np.nanmax(depth))
125
 
126
  if dmax - dmin > 1e-6:
 
130
 
131
  depth_8bit = (depth_norm * 255).astype(np.uint8)
132
 
133
+ # 6. 儲存
134
  os.makedirs("outputs", exist_ok=True)
135
  base = os.path.splitext(os.path.basename(image_path))[0]
136
  out_path = f"outputs/{base}_depth.png"
 
147
  depth_img, out_path = run_inference_and_save_depth(image, mask)
148
  return depth_img, out_path
149
 
 
150
  demo = gr.Interface(
151
  fn=gradio_fn,
152
  inputs=[
 
158
  gr.File(label="Download Depth PNG"),
159
  ],
160
  title="DA² — Minimal Depth Demo",
161
+ description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG (Resized to Original).",
162
  allow_flagging="never",
163
  )
164