kunkk commited on
Commit
a548e24
·
verified ·
1 Parent(s): a357391

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -72
app.py CHANGED
@@ -17,42 +17,45 @@ import base64
17
  # GPU/CPU设置
18
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
 
 
20
  # CSS样式设置
21
  custom_css = """
22
  :root {
23
  --primary-color: #2196F3;
24
  --secondary-color: #21CBF3;
25
  --background-color: #f6f8fa;
26
- --text-color: #333;
27
  --border-radius: 10px;
28
- --glass-bg: rgba(255, 255, 255, 0.25);
29
- --shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
 
 
30
  }
31
 
32
  .gradio-container {
33
- background: linear-gradient(135deg, var(--background-color), #ffffff);
34
  max-width: 1400px !important;
35
  margin: auto !important;
36
- backdrop-filter: blur(10px);
 
37
  }
38
 
39
  .output-image, .input-image {
40
  border-radius: var(--border-radius);
41
- box-shadow: var(--shadow);
42
  transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
43
- backdrop-filter: blur(10px);
44
- border: 1px solid rgba(255, 255, 255, 0.18);
45
  }
46
 
47
  .output-image:hover, .input-image:hover {
48
  transform: scale(1.02) translateY(-2px);
49
- box-shadow: 0 12px 40px 0 rgba(31, 38, 135, 0.5);
50
  }
51
 
52
  .custom-button {
53
- background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
54
- border: none;
55
- color: white;
56
  padding: 12px 24px;
57
  border-radius: var(--border-radius);
58
  cursor: pointer;
@@ -60,42 +63,41 @@ custom_css = """
60
  font-weight: bold;
61
  text-transform: uppercase;
62
  letter-spacing: 1px;
63
- box-shadow: var(--shadow);
64
  }
65
 
66
  .custom-button:hover {
67
- transform: translateY(-3px);
68
- box-shadow: 0 12px 30px rgba(33, 150, 243, 0.4);
 
69
  }
70
 
71
  .advanced-controls {
72
- background: var(--glass-bg);
73
- border-radius: 20px;
74
  padding: 25px;
75
- box-shadow: var(--shadow);
76
- backdrop-filter: blur(10px);
77
- border: 1px solid rgba(255, 255, 255, 0.18);
78
  }
79
 
80
  .result-container {
81
- background: var(--glass-bg);
82
- border-radius: 20px;
83
  padding: 20px;
84
- backdrop-filter: blur(15px);
85
- border: 1px solid rgba(255, 255, 255, 0.18);
86
- box-shadow: var(--shadow);
87
  }
88
 
89
  .interactive-viz {
90
  border-radius: 15px;
91
  overflow: hidden;
92
  transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
93
- box-shadow: var(--shadow);
 
94
  }
95
 
96
  .interactive-viz:hover {
97
- transform: translateY(-5px);
98
- box-shadow: 0 15px 35px rgba(0,0,0,0.15);
99
  }
100
 
101
  .statistics-container {
@@ -106,35 +108,89 @@ custom_css = """
106
  }
107
 
108
  .statistic-card {
109
- background: var(--glass-bg);
110
  padding: 20px;
111
  border-radius: var(--border-radius);
112
  text-align: center;
113
- box-shadow: var(--shadow);
114
- backdrop-filter: blur(10px);
115
- border: 1px solid rgba(255, 255, 255, 0.18);
116
  transition: all 0.3s ease;
 
 
117
  }
118
 
119
  .statistic-card:hover {
120
  transform: translateY(-2px);
121
- box-shadow: 0 10px 25px rgba(0,0,0,0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  }
123
 
124
  .progress-container {
125
- background: var(--glass-bg);
126
  border-radius: 10px;
127
  padding: 15px;
128
  margin: 10px 0;
129
- backdrop-filter: blur(10px);
 
130
  }
131
 
132
  .comparison-slider {
133
- background: var(--glass-bg);
134
  border-radius: 15px;
135
  padding: 20px;
136
- backdrop-filter: blur(10px);
137
- border: 1px solid rgba(255, 255, 255, 0.18);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  }
139
  """
140
  class ImageProcessor:
@@ -189,49 +245,102 @@ class ImageProcessor:
189
  return image
190
 
191
  def generate_analysis_plots(self, saliency_map):
192
- """生成分析图表"""
193
- plt.style.use('seaborn-v0_8')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
195
- # 直方图
196
- ax1.hist(saliency_map.flatten(), bins=50, color='#2196F3', alpha=0.7, edgecolor='black')
197
- ax1.set_title('Histogram of Saliency Distribution', fontsize=12, pad=15)
198
- ax1.set_xlabel('Saliency Value', fontsize=10)
199
- ax1.set_ylabel('Frequency', fontsize=10)
200
- ax1.grid(True, alpha=0.3)
 
 
 
 
 
 
201
  # 添加统计信息
202
  mean_val = np.mean(saliency_map)
203
  median_val = np.median(saliency_map)
204
- ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_val:.3f}')
205
- ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7, label=f'Median: {median_val:.3f}')
206
- ax1.legend()
 
 
 
207
  # 累积分布
208
  sorted_vals = np.sort(saliency_map.flatten())
209
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
210
- ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2)
211
- ax2.set_title('Cumulative Distribution Function', fontsize=12)
212
- ax2.set_xlabel('Saliency Value', fontsize=10)
213
- ax2.set_ylabel('Cumulative Probability', fontsize=10)
214
- ax2.grid(True, alpha=0.3)
 
 
 
215
  # 箱线图
216
- ax3.boxplot(saliency_map.flatten(), patch_artist=True,
217
- boxprops=dict(facecolor='#21CBF3', alpha=0.7))
218
- ax3.set_title('Boxplot of Saliency Distribution', fontsize=12)
219
- ax3.set_ylabel('Saliency Value', fontsize=10)
220
- ax3.grid(True, alpha=0.3)
 
 
 
 
 
 
 
 
221
  # 强度剖面(中心线)
222
  center_row = saliency_map[saliency_map.shape[0]//2, :]
223
- ax4.plot(center_row, color='#9C27B0', linewidth=2)
224
- ax4.set_title('Intensity Profile along Center Line', fontsize=12)
225
- ax4.set_xlabel('Pixel Position', fontsize=10)
226
- ax4.set_ylabel('Saliency Value', fontsize=10)
227
- ax4.grid(True, alpha=0.3)
228
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
229
  # 保存为字节
230
  buf = io.BytesIO()
231
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
 
232
  buf.seek(0)
233
  img_array = np.array(Image.open(buf))
234
  plt.close()
 
 
 
 
235
  return img_array
236
 
237
  def quick_process(self, image, threshold=0.5, testsize=256):
@@ -456,16 +565,19 @@ def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
456
  overlayed = overlay_with_contours
457
 
458
  # 生成统计信息HTML
 
459
  if stats:
460
- stats_html = "<div class='statistics-container'>"
461
  for key, value in stats.items():
462
- stats_html += f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
 
 
 
 
 
463
  stats_html += "</div>"
464
  else:
465
- stats_html = "<p>无可用统计信息</p>"
466
-
467
- return (original, saliency_map, heatmap, overlayed, segmented,
468
- comparison, time_info, stats_html, analysis_plot)
469
 
470
  def create_comparison_view(original, result, slider_value):
471
  """创建滑块对比视图"""
 
17
  # GPU/CPU设置
18
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
 
20
+ # CSS样式设置
21
  # CSS样式设置
22
  custom_css = """
23
  :root {
24
  --primary-color: #2196F3;
25
  --secondary-color: #21CBF3;
26
  --background-color: #f6f8fa;
27
+ --text-color: #000000;
28
  --border-radius: 10px;
29
+ --glass-bg: rgba(255, 255, 255, 0.95);
30
+ --shadow: 0 4px 16px 0 rgba(0, 0, 0, 0.2);
31
+ --print-border: 2px solid #000000;
32
+ --print-bg: #ffffff;
33
  }
34
 
35
  .gradio-container {
36
+ background: var(--print-bg);
37
  max-width: 1400px !important;
38
  margin: auto !important;
39
+ color: var(--text-color);
40
+ font-weight: 500;
41
  }
42
 
43
  .output-image, .input-image {
44
  border-radius: var(--border-radius);
45
+ border: var(--print-border);
46
  transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
47
+ background: var(--print-bg);
 
48
  }
49
 
50
  .output-image:hover, .input-image:hover {
51
  transform: scale(1.02) translateY(-2px);
52
+ box-shadow: 0 8px 24px 0 rgba(0, 0, 0, 0.3);
53
  }
54
 
55
  .custom-button {
56
+ background: var(--print-bg);
57
+ border: var(--print-border);
58
+ color: var(--text-color);
59
  padding: 12px 24px;
60
  border-radius: var(--border-radius);
61
  cursor: pointer;
 
63
  font-weight: bold;
64
  text-transform: uppercase;
65
  letter-spacing: 1px;
 
66
  }
67
 
68
  .custom-button:hover {
69
+ transform: translateY(-2px);
70
+ background: #f0f0f0;
71
+ box-shadow: 0 6px 16px rgba(0, 0, 0, 0.2);
72
  }
73
 
74
  .advanced-controls {
75
+ background: var(--print-bg);
76
+ border-radius: 15px;
77
  padding: 25px;
78
+ border: var(--print-border);
79
+ color: var(--text-color);
 
80
  }
81
 
82
  .result-container {
83
+ background: var(--print-bg);
84
+ border-radius: 15px;
85
  padding: 20px;
86
+ border: var(--print-border);
87
+ color: var(--text-color);
 
88
  }
89
 
90
  .interactive-viz {
91
  border-radius: 15px;
92
  overflow: hidden;
93
  transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
94
+ border: var(--print-border);
95
+ background: var(--print-bg);
96
  }
97
 
98
  .interactive-viz:hover {
99
+ transform: translateY(-3px);
100
+ box-shadow: 0 8px 20px rgba(0,0,0,0.2);
101
  }
102
 
103
  .statistics-container {
 
108
  }
109
 
110
  .statistic-card {
111
+ background: var(--print-bg);
112
  padding: 20px;
113
  border-radius: var(--border-radius);
114
  text-align: center;
115
+ border: var(--print-border);
 
 
116
  transition: all 0.3s ease;
117
+ color: var(--text-color);
118
+ font-weight: bold;
119
  }
120
 
121
  .statistic-card:hover {
122
  transform: translateY(-2px);
123
+ box-shadow: 0 6px 16px rgba(0,0,0,0.15);
124
+ }
125
+
126
+ .statistic-card h4 {
127
+ color: var(--text-color);
128
+ font-weight: bold;
129
+ margin-bottom: 10px;
130
+ font-size: 14px;
131
+ }
132
+
133
+ .statistic-card p {
134
+ color: var(--text-color);
135
+ font-weight: bold;
136
+ font-size: 16px;
137
+ margin: 0;
138
  }
139
 
140
  .progress-container {
141
+ background: var(--print-bg);
142
  border-radius: 10px;
143
  padding: 15px;
144
  margin: 10px 0;
145
+ border: var(--print-border);
146
+ color: var(--text-color);
147
  }
148
 
149
  .comparison-slider {
150
+ background: var(--print-bg);
151
  border-radius: 15px;
152
  padding: 20px;
153
+ border: var(--print-border);
154
+ color: var(--text-color);
155
+ }
156
+
157
+ /* 打印优化 */
158
+ @media print {
159
+ .gradio-container {
160
+ background: white !important;
161
+ color: black !important;
162
+ }
163
+
164
+ .output-image, .input-image, .advanced-controls,
165
+ .result-container, .interactive-viz, .statistic-card,
166
+ .progress-container, .comparison-slider {
167
+ background: white !important;
168
+ border: 2px solid black !important;
169
+ color: black !important;
170
+ box-shadow: none !important;
171
+ }
172
+
173
+ .custom-button {
174
+ background: white !important;
175
+ border: 2px solid black !important;
176
+ color: black !important;
177
+ }
178
+
179
+ .statistic-card h4, .statistic-card p {
180
+ color: black !important;
181
+ font-weight: bold !important;
182
+ }
183
+ }
184
+
185
+ /* 高对比度文本 */
186
+ h1, h2, h3, h4, h5, h6, p, span, div, label {
187
+ color: var(--text-color) !important;
188
+ font-weight: 500 !important;
189
+ }
190
+
191
+ .gr-textbox, .gr-dropdown, .gr-radio, .gr-checkbox {
192
+ color: var(--text-color) !important;
193
+ border: 1px solid #000000 !important;
194
  }
195
  """
196
  class ImageProcessor:
 
245
  return image
246
 
247
  def generate_analysis_plots(self, saliency_map):
248
+ """生成分析图表 - 使用原始显著性图而非二值化"""
249
+ plt.style.use('default') # 使用默认样式以确保打印清晰
250
+
251
+ # 设置全局字体和颜色
252
+ plt.rcParams.update({
253
+ 'font.size': 12,
254
+ 'font.weight': 'bold',
255
+ 'axes.labelweight': 'bold',
256
+ 'axes.titleweight': 'bold',
257
+ 'figure.facecolor': 'white',
258
+ 'axes.facecolor': 'white',
259
+ 'text.color': 'black',
260
+ 'axes.labelcolor': 'black',
261
+ 'axes.edgecolor': 'black',
262
+ 'xtick.color': 'black',
263
+ 'ytick.color': 'black'
264
+ })
265
+
266
  fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
267
+ fig.patch.set_facecolor('white')
268
+
269
+ # 直方图 - 使用原始显著性值
270
+ ax1.hist(saliency_map.flatten(), bins=50, color='black', alpha=0.7,
271
+ edgecolor='black', linewidth=1.5)
272
+ ax1.set_title('Histogram of Saliency Distribution', fontsize=14,
273
+ fontweight='bold', color='black', pad=15)
274
+ ax1.set_xlabel('Saliency Value', fontsize=12, fontweight='bold', color='black')
275
+ ax1.set_ylabel('Frequency', fontsize=12, fontweight='bold', color='black')
276
+ ax1.grid(True, alpha=0.5, color='gray', linewidth=1)
277
+ ax1.set_facecolor('white')
278
+
279
  # 添加统计信息
280
  mean_val = np.mean(saliency_map)
281
  median_val = np.median(saliency_map)
282
+ ax1.axvline(mean_val, color='black', linestyle='--', linewidth=2,
283
+ alpha=0.8, label=f'Mean: {mean_val:.3f}')
284
+ ax1.axvline(median_val, color='gray', linestyle=':', linewidth=2,
285
+ alpha=0.8, label=f'Median: {median_val:.3f}')
286
+ ax1.legend(fontsize=10, frameon=True, facecolor='white', edgecolor='black')
287
+
288
  # 累积分布
289
  sorted_vals = np.sort(saliency_map.flatten())
290
  cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
291
+ ax2.plot(sorted_vals, cumulative, color='black', linewidth=3)
292
+ ax2.set_title('Cumulative Distribution Function', fontsize=14,
293
+ fontweight='bold', color='black')
294
+ ax2.set_xlabel('Saliency Value', fontsize=12, fontweight='bold', color='black')
295
+ ax2.set_ylabel('Cumulative Probability', fontsize=12, fontweight='bold', color='black')
296
+ ax2.grid(True, alpha=0.5, color='gray', linewidth=1)
297
+ ax2.set_facecolor('white')
298
+
299
  # 箱线图
300
+ box_plot = ax3.boxplot(saliency_map.flatten(), patch_artist=True,
301
+ boxprops=dict(facecolor='lightgray', color='black', linewidth=2),
302
+ whiskerprops=dict(color='black', linewidth=2),
303
+ capprops=dict(color='black', linewidth=2),
304
+ medianprops=dict(color='black', linewidth=3),
305
+ flierprops=dict(marker='o', markerfacecolor='gray',
306
+ markeredgecolor='black', markersize=5))
307
+ ax3.set_title('Boxplot of Saliency Distribution', fontsize=14,
308
+ fontweight='bold', color='black')
309
+ ax3.set_ylabel('Saliency Value', fontsize=12, fontweight='bold', color='black')
310
+ ax3.grid(True, alpha=0.5, color='gray', linewidth=1)
311
+ ax3.set_facecolor('white')
312
+
313
  # 强度剖面(中心线)
314
  center_row = saliency_map[saliency_map.shape[0]//2, :]
315
+ ax4.plot(center_row, color='black', linewidth=3, marker='o',
316
+ markersize=3, markerfacecolor='gray', markeredgecolor='black')
317
+ ax4.set_title('Intensity Profile along Center Line', fontsize=14,
318
+ fontweight='bold', color='black')
319
+ ax4.set_xlabel('Pixel Position', fontsize=12, fontweight='bold', color='black')
320
+ ax4.set_ylabel('Saliency Value', fontsize=12, fontweight='bold', color='black')
321
+ ax4.grid(True, alpha=0.5, color='gray', linewidth=1)
322
+ ax4.set_facecolor('white')
323
+
324
+ # 设置所有轴的颜色
325
+ for ax in [ax1, ax2, ax3, ax4]:
326
+ ax.tick_params(colors='black', labelsize=10)
327
+ for spine in ax.spines.values():
328
+ spine.set_color('black')
329
+ spine.set_linewidth(1.5)
330
+
331
+ plt.tight_layout(pad=3.0)
332
+
333
  # 保存为字节
334
  buf = io.BytesIO()
335
+ plt.savefig(buf, format='png', dpi=300, bbox_inches='tight',
336
+ facecolor='white', edgecolor='black')
337
  buf.seek(0)
338
  img_array = np.array(Image.open(buf))
339
  plt.close()
340
+
341
+ # 重置matplotlib参数
342
+ plt.rcParams.update(plt.rcParamsDefault)
343
+
344
  return img_array
345
 
346
  def quick_process(self, image, threshold=0.5, testsize=256):
 
565
  overlayed = overlay_with_contours
566
 
567
  # 生成统计信息HTML
568
+ # 生成统计信息HTML
569
  if stats:
570
+ stats_html = "<div class='statistics-container' style='background: white; color: black;'>"
571
  for key, value in stats.items():
572
+ stats_html += f"""
573
+ <div class='statistic-card' style='background: white; border: 2px solid black; color: black;'>
574
+ <h4 style='color: black; font-weight: bold; margin-bottom: 10px;'>{key}</h4>
575
+ <p style='color: black; font-weight: bold; font-size: 16px; margin: 0;'>{value}</p>
576
+ </div>
577
+ """
578
  stats_html += "</div>"
579
  else:
580
+ stats_html = "<p style='color: black; font-weight: bold;'>无可用统计信息</p>"
 
 
 
581
 
582
  def create_comparison_view(original, result, slider_value):
583
  """创建滑块对比视图"""