kunkk committed on
Commit
20b1b91
·
verified ·
1 Parent(s): cf39a74

Upload 2 files

Browse files
Files changed (2) hide show
  1. model/CyueNet_models.py +696 -0
  2. model/MobileNetV2.py +123 -0
model/CyueNet_models.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import einops
5
+
6
+ from timm.models.layers import trunc_normal_
7
+
8
+ from einops import rearrange
9
+ import math
10
+
11
+ from model.MobileNetV2 import mobilenet_v2
12
+ from torch.nn import Parameter
13
+
14
+
15
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Convolution kernel size (int or tuple).
        stride: Convolution stride.
        padding: Zero padding applied by the convolution.
        dilation: Convolution dilation.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,  # BatchNorm directly after the conv supplies the shift
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
29
+
30
+
31
class Reduction(nn.Module):
    """Map a feature map to ``out_channel`` channels with three conv blocks.

    The stack is a 1x1 ``BasicConv2d`` (channel change) followed by two
    3x3 ``BasicConv2d`` blocks with padding 1.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = [BasicConv2d(in_channel, out_channel, 1)]
        stages.extend(
            BasicConv2d(out_channel, out_channel, 3, padding=1) for _ in range(2)
        )
        self.reduce = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the reduction stack to ``x``."""
        return self.reduce(x)
42
+
43
+
44
class TopDownLayer(nn.Module):
    """Fuse a feature map ``x`` into a second map ``x2`` of another resolution.

    ``x`` goes through a 3x3 conv + BN + ReLU, is bilinearly resized to the
    spatial size of ``x2``, concatenated with ``x2`` along the channel axis
    and compressed back to ``channel`` channels by a 1x1 conv + BN + ReLU.
    Both inputs are expected to carry ``channel`` channels (the compress conv
    takes ``2 * channel`` inputs).
    """

    def __init__(self, channel):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channel, channel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(channel),
        )
        self.relu = nn.ReLU()
        self.channel_compress = nn.Sequential(
            nn.Conv2d(channel * 2, channel, 1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU(),
        )

    def forward(self, x, x2):
        """Return the fused map at the resolution of ``x2``."""
        refined = self.relu(self.conv(x))
        target_hw = x2.size()[2:]
        refined = F.interpolate(refined, target_hw, mode='bilinear', align_corners=True)
        stacked = torch.cat((refined, x2), dim=1)
        return self.channel_compress(stacked)
68
+
69
+
70
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over the spatial positions of feature maps.

    Despite the name, ``forward`` performs a single attention operation: the
    three inputs are flattened to token sequences, projected by the *shared*
    linear layer ``inb`` and combined with softmax attention.  ``head``,
    ``d_k`` and the ``linear_query`` / ``linear_key`` / ``linear_value``
    layers are created but not used by ``forward``; they are kept so existing
    checkpoints still load.

    Args:
        head: Number of heads; only validated against ``d_model``.
        d_model: Output width of the ``inb`` projection.
        dropout: Dropout probability applied to the attention weights.

    NOTE(review): ``inb`` is hard-coded to 32 input features, so the inputs
    must have 32 channels regardless of ``d_model``.
    """

    def __init__(self, head=8, d_model=32, dropout=0.1):
        super().__init__()
        assert (d_model % head == 0)
        self.d_k = d_model // head
        self.head = head
        self.d_model = d_model
        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.attn = None  # attention weights of the most recent forward pass
        self.inb = nn.Linear(32, d_model)

    def self_attention(self, query, key, value, mask=None):
        """Return ``(softmax(Q K^T / sqrt(d)) V, attention_weights)``.

        ``mask`` is accepted for interface compatibility but is not applied.
        The returned weights are the ones after dropout.
        """
        d_k = query.shape[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        self_attn = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            self_attn = self.dropout(self_attn)

        return torch.matmul(self_attn, value), self_attn

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key`` / ``value`` positions.

        Each input is flattened from dim 2 onwards and transposed to
        ``(batch, tokens, channels)`` before the shared projection.

        Returns:
            Tensor of shape ``(batch, d_model, H, W)`` where ``(H, W)`` is the
            spatial size of ``query``.  For inputs that are not 4-D the token
            axis is folded into a square, as the original implementation did.
        """
        n_batch = query.size(0)
        spatial = tuple(query.shape[2:])

        query = self.inb(query.flatten(start_dim=2).permute(0, 2, 1))
        key = self.inb(key.flatten(start_dim=2).permute(0, 2, 1))
        value = self.inb(value.flatten(start_dim=2).permute(0, 2, 1))

        x, self.attn = self.self_attention(query, key, value, mask=mask)

        x = x.permute(0, 2, 1)  # (batch, d_model, tokens)
        channels = x.size(1)

        if len(spatial) == 2:
            # Restore the query's real height/width.  The previous
            # int(sqrt(tokens)) reshape only worked for square maps; for those
            # this row-major reshape yields the identical tensor.
            return x.reshape(n_batch, channels, spatial[0], spatial[1])

        # Non 4-D input: keep the legacy square fold of the token axis.
        side = int(x.size(-1) ** 0.5)
        return x.reshape(n_batch, channels, side, side)
119
+
120
+
121
class Upsample(nn.Module):
    """Parameter-free bilinear resize of one map to another map's size."""

    def __init__(self):
        super().__init__()

    def forward(self, x, x2):
        """Resize ``x`` to the spatial size of ``x2`` (``align_corners=True``)."""
        target_hw = x2.size()[2:]
        return F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
129
+
130
+
131
class MultiScaleAttention(nn.Module):
    """Gate six feature maps with spatial attention and merge them.

    Every input is multiplied by a spatial attention map, inputs 1-5 are
    resized to the size of ``x0``, all six are concatenated on the channel
    axis and a 1x1 conv reduces ``6 * channel`` channels to ``channel``.

    NOTE(review): five ``SpatialAttention`` modules are registered, but
    ``forward`` uses only ``attention_branches[0]`` for all six inputs.  This
    looks unintended; it is preserved because changing it would alter the
    behaviour of already-trained weights.
    """

    def __init__(self, channel):
        super().__init__()
        # Spatial attention modules (only index 0 is used by forward).
        self.attention_branches = nn.ModuleList([SpatialAttention() for _ in range(5)])
        self.upsample = Upsample()
        self.conv_reduce = nn.Conv2d(channel * 6, channel, kernel_size=1)

    def forward(self, x0, x1, x2, x3, x4, x5):
        """Return the fused map at the spatial size of ``x0``."""
        shared_gate = self.attention_branches[0]
        gated = [shared_gate(feat) * feat for feat in (x0, x1, x2, x3, x4, x5)]

        # The first map is the size reference and is not resized itself.
        aligned = [gated[0]]
        aligned.extend(self.upsample(feat, x0) for feat in gated[1:])

        return self.conv_reduce(torch.cat(aligned, dim=1))
159
+
160
+
161
class Basic2(nn.Module):
    """Six parallel conv branches fused by ``MultiScaleAttention`` plus a residual.

    ``branch0`` is a 1x1 conv.  ``branch1`` .. ``branch5`` each apply a
    (1, k) conv, a (k, 1) conv and a 3x3 conv with dilation k, for
    k = 3, 5, 7, 9, 11.  All branches preserve the spatial size.

    NOTE(review): ``conv_combine`` takes ``in_channel`` inputs while the fused
    map has ``out_channel`` channels, and the result is added to the input, so
    this only works when ``in_channel == out_channel``.
    NOTE(review): ``channel_attention`` is assigned twice (the second
    assignment wins, leaving a ``SpatialAttention``) and is never used by
    ``forward``; both assignments are kept so parameter-initialisation order
    and state_dict keys stay as they were.
    """

    @staticmethod
    def _asymmetric_branch(in_channel, out_channel, k):
        """Build the (1, k) -> (k, 1) -> dilated 3x3 branch for odd ``k``."""
        half = k // 2
        return nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, k), padding=(0, half)),
            BasicConv2d(out_channel, out_channel, kernel_size=(k, 1), padding=(half, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=k, dilation=k),
        )

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.relu = nn.ReLU(True)
        # join
        self.channel_attention = ChannelAttention(out_channel)
        self.channel_attention = SpatialAttention()

        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = self._asymmetric_branch(in_channel, out_channel, 3)
        self.branch2 = self._asymmetric_branch(in_channel, out_channel, 5)
        self.branch3 = self._asymmetric_branch(in_channel, out_channel, 7)
        self.branch4 = self._asymmetric_branch(in_channel, out_channel, 9)
        self.branch5 = self._asymmetric_branch(in_channel, out_channel, 11)

        self.multi_scale_attention = MultiScaleAttention(out_channel)
        self.conv_combine = BasicConv2d(in_channel, in_channel, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``conv_combine(fuse(branches(x))) + x``."""
        branches = (
            self.branch0, self.branch1, self.branch2,
            self.branch3, self.branch4, self.branch5,
        )
        branch_outputs = [branch(x) for branch in branches]
        fused = self.multi_scale_attention(*branch_outputs)
        combined = self.conv_combine(fused)
        return combined + x
215
+
216
+
217
class ChannelAttention(nn.Module):
    """Per-channel gate computed from global average and max pooling.

    Both pooled descriptors go through the same bias-free bottleneck
    (``in_planes -> in_planes // 2 -> in_planes`` 1x1 convs with a ReLU in
    between); their sum is squashed by a sigmoid.
    """

    def __init__(self, in_planes):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = in_planes // 2
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return gate weights of shape ``(batch, in_planes, 1, 1)``."""
        from_avg = self.fc(self.avg_pool(x))
        from_max = self.fc(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)
233
+
234
+
235
class SpatialAttention(nn.Module):
    """Per-pixel gate computed from channel-wise mean and max maps.

    The two single-channel maps are concatenated and passed through one
    bias-free ``kernel_size`` x ``kernel_size`` conv (padding
    ``kernel_size // 2``) and a sigmoid.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        same_padding = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return gate weights of shape ``(batch, 1, H, W)``."""
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv1(descriptor))
248
+
249
+
250
class MModule(nn.Module):
    """``Basic2`` block followed by residual channel and spatial attention.

    The output is ``x`` plus the attended features, so input and output have
    the same shape.
    """

    def __init__(self, channel):
        super().__init__()
        self.basic = Basic2(channel, channel)
        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)

    def forward(self, x):
        """Return the attended features with a skip connection from ``x``."""
        feat = self.basic(x)
        # Each attention stage is residual: features * gate + features.
        feat = feat * self.CA(feat) + feat
        feat = feat * self.SA(feat) + feat
        return feat + x
266
+
267
+
268
class MNodule(nn.Module):
    """Three cascaded multi-scale stages followed by channel/spatial attention.

    Each stage runs an asymmetric-conv branch (1x1, (1, k), (k, 1) with
    k = 3, 5, 7) in parallel with a dilated 3x3 conv (dilation 3, 5, 7),
    fuses the two with a 3x3 conv and adds the input plus all previous stage
    outputs.  The last stage output is projected by a 1x1 conv, gated by
    channel and then spatial attention, refined by two 3x3 convs and summed
    with every stage output and the input.

    ``self.sigmoid`` is registered but not used by ``forward``.
    """

    def __init__(self, channel):
        super(MNodule, self).__init__()
        self.atrconv1 = BasicConv2d(channel, channel, 3, padding=3, dilation=3)
        self.atrconv2 = BasicConv2d(channel, channel, 3, padding=5, dilation=5)
        self.atrconv3 = BasicConv2d(channel, channel, 3, padding=7, dilation=7)
        self.branch1 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(channel, channel, kernel_size=(3, 1), padding=(1, 0))
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(channel, channel, kernel_size=(5, 1), padding=(2, 0))
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(channel, channel, kernel_size=(7, 1), padding=(3, 0))
        )

        self.conv_cat1 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat2 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat3 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv1_1 = BasicConv2d(channel, channel, 1)

        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)

        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the refined feature map (same shape as ``x``)."""
        # Stage 1: k = 3 branch + dilation-3 conv, residual from the input.
        x1 = self.branch1(x)
        x_atr1 = self.atrconv1(x)
        s_mfeb1 = self.conv_cat1(torch.cat((x1, x_atr1), 1)) + x

        # Stage 2: operates on stage 1, residual from stage 1 and the input.
        x2 = self.branch2(s_mfeb1)
        x_atr2 = self.atrconv2(s_mfeb1)
        s_mfeb2 = self.conv_cat2(torch.cat((x2, x_atr2), 1)) + s_mfeb1 + x

        # Stage 3: operates on stage 2, residual from both earlier stages.
        x3 = self.branch3(s_mfeb2)
        x_atr3 = self.atrconv3(s_mfeb2)
        s_mfeb3 = self.conv_cat3(torch.cat((x3, x_atr3), 1)) + s_mfeb1 + s_mfeb2 + x

        x_m = self.conv1_1(s_mfeb3)

        # Channel attention, then spatial attention.  A second, unused copy of
        # the channel-attention product used to be computed here; it was dead
        # work and has been removed.
        x_ca = self.CA(x_m) * x_m

        x_mix = self.sal_conv((self.SA(x_ca)) * x_ca) + s_mfeb1 + s_mfeb2 + s_mfeb3 + x

        return x_mix
322
+
323
+
324
class TransBasicConv2d(nn.Module):
    """``ConvTranspose2d`` followed by ``BatchNorm2d`` and an in-place ReLU.

    With the default ``kernel_size=2, stride=2, padding=0`` the spatial size
    is doubled.
    """

    def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False):
        super().__init__()
        deconv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.Deconv = nn.ConvTranspose2d(in_planes, out_planes, **deconv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(deconv(x)))``."""
        return self.relu(self.bn(self.Deconv(x)))
338
+
339
+
340
class features(nn.Module):
    """Five independent 1x1 ``BasicConv2d`` blocks, one per input level."""

    def __init__(self, channel):
        super().__init__()
        self.conv1 = BasicConv2d(channel, channel, 1)
        self.conv2 = BasicConv2d(channel, channel, 1)
        self.conv3 = BasicConv2d(channel, channel, 1)
        self.conv4 = BasicConv2d(channel, channel, 1)
        self.conv5 = BasicConv2d(channel, channel, 1)

    def forward(self, x1, x2, x3, x4, x5):
        """Return a 5-tuple with ``conv_i`` applied to ``x_i``."""
        out1 = self.conv1(x1)
        out2 = self.conv2(x2)
        out3 = self.conv3(x3)
        out4 = self.conv4(x4)
        out5 = self.conv5(x5)
        return out1, out2, out3, out4, out5
357
+
358
+
359
class conv_upsamle(nn.Module):
    """Resize a map to a target's spatial size (if needed), then 3x3 conv."""

    def __init__(self, channel):
        super().__init__()
        self.conv = BasicConv2d(channel, channel, 3, padding=1)

    def forward(self, x, target):
        """Return ``conv(x)`` at the spatial size of ``target``."""
        target_hw = target.size()[2:]
        # Interpolation is skipped when the sizes already agree.
        if x.size()[2:] != target_hw:
            x = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        return self.conv(x)
369
+
370
+
371
class AP_MP(nn.Module):
    """Channel-wise L2 distance between an average-pooled and a max-pooled map.

    Args:
        stride: Used as both kernel size and stride of the two pooling layers.
    """

    def __init__(self, stride=2):
        super().__init__()
        self.sz = stride
        self.gapLayer = nn.AvgPool2d(kernel_size=self.sz, stride=self.sz)
        self.gmpLayer = nn.MaxPool2d(kernel_size=self.sz, stride=self.sz)

    def forward(self, x1, x2):
        """Return ``||avgpool(x1) - maxpool(x2)||_2`` over dim 1, keeping the dim."""
        # Unpacking four sizes rejects inputs that are not 4-D.
        _batch, _channels, _height, _width = x1.size()
        difference = self.gapLayer(x1) - self.gmpLayer(x2)
        return torch.norm(abs(difference), p=2, dim=1, keepdim=True)
384
+
385
class MOM(nn.Module):
    """Fuse two same-shaped feature maps with cross spatial attention.

    Each input is refined by a 3x3 conv and residual channel attention.  Each
    stream is then gated by the *other* stream's spatial attention, the two
    are concatenated and reduced by a 1x1 conv, and the first stream is added
    back.

    ``glbamp``, ``upsample2`` and ``upSA`` are registered but not used by
    ``forward``.
    """

    def __init__(self, channel):
        super().__init__()
        self.channel = channel

        self.conv1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv2 = BasicConv2d(channel, channel, 3, padding=1)

        self.CA1 = ChannelAttention(self.channel)
        self.CA2 = ChannelAttention(self.channel)
        self.SA1 = SpatialAttention()
        self.SA2 = SpatialAttention()

        self.glbamp = AP_MP()
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = BasicConv2d(channel * 2, channel, kernel_size=1, stride=1)

        self.upSA = SpatialAttention()

    def forward(self, x1, x2):
        """Return ``(res, res, res)``: one fused tensor under three names."""
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)

        # Residual channel attention on each stream.
        x1 = x1 + x1 * self.CA1(x1)
        x2 = x2 + x2 * self.CA2(x2)

        # Cross gating: stream 1 uses the spatial map of stream 2 and vice versa.
        gated1 = x1 + x1 * self.SA2(x2)
        gated2 = x2 + x2 * self.SA1(x1)

        fused = self.conv(torch.cat([gated1, gated2], dim=1))
        fused = fused + x1

        # Saliency, edge and skeleton outputs are the same tensor object.
        return fused, fused, fused
421
+
422
+
423
class AFM(nn.Module):
    """Gate ``x1`` by a globally pooled descriptor of ``x2`` and refine it.

    ``x2`` is reduced to ``(batch, C, 1, 1)`` by global max pooling and a
    sigmoid, multiplied into ``x1``, projected by a 1x1 conv, gated by channel
    and spatial attention and refined by two 3x3 convs.  The result is summed
    with ``x1``, the pooled gate (broadcast) and the gated input.

    ``ca2`` is registered but not used by ``forward``; it is kept so existing
    checkpoints still load.
    """

    def __init__(self, channel):
        super(AFM, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # A second, identical ``self.sigmoid = nn.Sigmoid()`` assignment
        # further down was redundant and has been removed.
        self.sigmoid = nn.Sigmoid()
        self.conv1_1 = nn.Conv2d(channel, channel, kernel_size=1)

        self.ca1 = ChannelAttention(channel)
        self.ca2 = ChannelAttention(channel)

        self.sa = SpatialAttention()

        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )

    def forward(self, x1, x2):
        """Return ``(s_mea, s_mea, s_mea)``: one tensor under three names."""
        x2 = self.sigmoid(self.max_pool(x2))  # per-channel gate, spatial size 1x1
        xb = x2 * x1
        x = self.conv1_1(xb)
        x_c = self.ca1(x) * x
        # The former ``x_d = self.ca2(x) * x`` never reached the output and
        # has been dropped to avoid the wasted computation.

        s_mea = self.sal_conv((self.sa(x_c)) * x_c) + x1 + x2 + xb
        ske = s_mea
        e_pred = s_mea
        return s_mea, e_pred, ske
452
+
453
class DummyMOM(nn.Module):
    """Minimal stand-in for the fusion modules: concatenate, then 1x1 conv.

    NOTE(review): the ``channel`` argument is ignored; the conv is fixed at
    64 input and 32 output channels, so the two inputs must total 64 channels.
    """

    def __init__(self, channel):
        super().__init__()
        # Identity placeholders: input and output stay the same.  They are
        # not called by forward.
        self.conv1 = nn.Identity()
        self.conv2 = nn.Identity()

        # 1x1 conv taking the 64 concatenated channels down to 32.
        self.conv = nn.Conv2d(64, 32, kernel_size=1)

    def forward(self, x1, x2):
        """Return ``(res, res, res)`` with ``res = conv(cat(x1, x2))``."""
        # Concatenate first, then bring the channel count to 32.
        stacked = torch.cat([x1, x2], dim=1)
        res = self.conv(stacked)
        return res, res, res
468
+
469
class YUEM(nn.Module):
    """Attention fusion of two feature maps.

    ``x1`` is refined by ``MModule`` and used as the attention query; ``x2``
    refined by ``MNodule`` is the key; the raw ``x2`` is the value.

    NOTE(review): ``MultiHeadAttention(channel)`` passes ``channel`` as the
    first positional argument, i.e. as ``head`` and not ``d_model``.  With
    ``channel == 32`` this satisfies the ``d_model % head`` check; other
    values may not.  Kept as written to preserve behaviour.
    """

    def __init__(self, channel):
        super().__init__()
        self.channel = channel
        self.m1 = MModule(self.channel)
        self.m2 = MNodule(self.channel)
        self.mha = MultiHeadAttention(channel)

    def forward(self, x1, x2):
        """Return ``(res, res, res)``: the attention output under three names."""
        query_feat = self.m1(x1)
        key_feat = self.m2(x2)
        res = self.mha(query_feat, key_feat, x2)
        return res, res, res
486
+
487
+
488
class MTG(nn.Module):
    """Merge saliency, edge and skeleton features into one ``channel``-wide map.

    The three inputs are concatenated on the channel axis and passed through
    two 3x3 ``BasicConv2d`` blocks.

    NOTE(review): five conv stacks are registered in ``ccs`` but ``forward``
    only uses ``ccs[0]``.
    """

    def __init__(self, channel):
        super().__init__()

        def fusion_stack():
            return nn.Sequential(
                BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
                BasicConv2d(channel, channel, kernel_size=3, padding=1),
            )

        self.ccs = nn.ModuleList([fusion_stack() for _ in range(5)])

    def forward(self, x_sal, x_edg, x_ske):
        """Return ``ccs[0](cat(x_sal, x_edg, x_ske))``."""
        stacked = torch.cat((x_sal, x_edg, x_ske), dim=1)
        return self.ccs[0](stacked)
503
class MMS(nn.Module):
    """MobileNetV2-based network with saliency, edge and skeleton outputs.

    The backbone yields five feature maps (16, 24, 32, 96 and 320 channels),
    each reduced to ``channel`` channels.  Level 5 is processed by ``AFM``,
    levels 4 and 3 by ``YUEM`` and levels 2 and 1 by ``MOM``, each combining a
    level with the 2x-upsampled next-deeper level.  Three parallel decoders
    (saliency / edge / skeleton) then run top-down from level 5 to level 1,
    exchanging information through ``MTG`` at every level.

    Args:
        pretrained: Forwarded to ``mobilenet_v2``.
        channel: Working channel width of the decoder.

    Registered but not used by ``forward``: ``S1`` .. ``S5`` and indices 1-4
    of ``ccs`` / ``cme`` / ``cms``.
    """

    def __init__(self, pretrained=True, channel=32):
        super().__init__()

        def fusion_stack(in_multiple):
            # Two 3x3 conv blocks; the first one takes in_multiple * channel inputs.
            return nn.Sequential(
                BasicConv2d(in_multiple * channel, channel, kernel_size=3, padding=1),
                BasicConv2d(channel, channel, kernel_size=3, padding=1),
            )

        def upsampling_deconv():
            return TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                    padding=0, dilation=1, bias=False)

        def side_head():
            return nn.Sequential(
                BasicConv2d(channel, channel, 3, padding=1),
                nn.Conv2d(channel, 1, 1),
            )

        self.backbone = mobilenet_v2(pretrained)

        # Per-level channel reduction of the backbone outputs.
        self.Translayer1 = Reduction(16, channel)
        self.Translayer2 = Reduction(24, channel)
        self.Translayer3 = Reduction(32, channel)
        self.Translayer4 = Reduction(96, channel)
        self.Translayer5 = Reduction(320, channel)

        self.trans_conv1 = upsampling_deconv()
        self.trans_conv2 = upsampling_deconv()
        self.trans_conv3 = upsampling_deconv()
        self.trans_conv4 = upsampling_deconv()

        self.mom = MOM(channel)
        # self.mom = DummyMOM(channel)
        self.afm = AFM(channel)
        # self.afm = DummyMOM(channel)
        self.yuem = YUEM(channel)
        # self.yuem = DummyMOM(channel)

        self.sigmoid = nn.Sigmoid()

        self.sal_features = features(channel)
        self.edg_features = features(channel)
        self.ske_features = features(channel)
        self.MTG = MTG(channel)

        self.ccs = nn.ModuleList([fusion_stack(3) for _ in range(5)])
        self.cme = nn.ModuleList([fusion_stack(3) for _ in range(5)])
        self.cms = nn.ModuleList([fusion_stack(3) for _ in range(5)])

        self.conv_cats = nn.ModuleList([fusion_stack(2) for _ in range(12)])

        self.cus = nn.ModuleList([conv_upsamle(channel) for _ in range(12)])
        self.prediction = nn.ModuleList([
            nn.Sequential(
                BasicConv2d(channel, channel, kernel_size=3, padding=1),
                nn.Conv2d(channel, 1, kernel_size=1),
            ) for _ in range(3)
        ])

        self.S1 = side_head()
        self.S2 = side_head()
        self.S3 = side_head()
        self.S4 = side_head()
        self.S5 = side_head()

    def forward(self, x):
        """Run the network on an image batch ``x``.

        Returns:
            A 31-tuple.  Element 0 is the level-1 saliency *feature* map.
            Then, with every map resized to the input size and ``sig`` meaning
            sigmoid: ``sal1, sig(sal1), edg1, sig(edg1)``; for each level
            2..5 ``sal, edg, sig(sal), sig(edg)``; and for each level 1..5
            ``ske, sig(ske)``.
        """
        size = x.size()[2:]
        conv1, conv2, conv3, conv4, conv5 = self.backbone(x)

        conv1 = self.Translayer1(conv1)
        conv2 = self.Translayer2(conv2)
        conv3 = self.Translayer3(conv3)
        conv4 = self.Translayer4(conv4)
        conv5 = self.Translayer5(conv5)

        # Level-wise fusion, deepest level first.
        rgc5, edg5, ske5 = self.afm(conv5, conv5)
        rgc4, edg4, ske4 = self.yuem(conv4, self.trans_conv4(conv5))
        rgc3, edg3, ske3 = self.yuem(conv3, self.trans_conv3(conv4))
        rgc2, edg2, ske2 = self.mom(conv2, self.trans_conv2(conv3))
        rgc1, edg1, ske1 = self.mom(conv1, self.trans_conv1(conv2))

        # Task-specific 1x1 projections; dict keys are the level numbers 1..5.
        levels = (1, 2, 3, 4, 5)
        sal = dict(zip(levels, self.sal_features(rgc1, rgc2, rgc3, rgc4, rgc5)))
        edg = dict(zip(levels, self.edg_features(edg1, edg2, edg3, edg4, edg5)))
        ske = dict(zip(levels, self.ske_features(ske1, ske2, ske3, ske4, ske5)))

        # Level 5 cross-task fusion.
        # NOTE(review): the saliency and edge fusions concatenate x_sal5 twice
        # where the skeleton fusion uses x_ske5 - possibly a typo.  Preserved,
        # since trained weights depend on it.
        sal_edg_sal5 = torch.cat((sal[5], edg[5], sal[5]), 1)
        sal_edg_ske5 = torch.cat((sal[5], edg[5], ske[5]), 1)
        sal_n = {5: self.ccs[0](sal_edg_sal5) + sal[5]}
        edg_n = {5: self.cme[0](sal_edg_sal5) + edg[5]}
        ske_n = {5: self.cms[0](sal_edg_ske5) + ske[5]}

        # Top-down decoding, levels 4 -> 1.  Each level uses three consecutive
        # conv_cats / cus modules (saliency, edge, skeleton).
        for step, level in enumerate((4, 3, 2, 1)):
            base = 3 * step
            deeper = level + 1

            cur_sal = self.conv_cats[base](
                torch.cat((sal[level], self.cus[base](sal_n[deeper], sal[level])), 1))
            cur_edg = self.conv_cats[base + 1](
                torch.cat((edg[level], self.cus[base + 1](edg_n[deeper], edg[level])), 1))
            cur_ske = self.conv_cats[base + 2](
                torch.cat((ske[level], self.cus[base + 2](ske_n[deeper], ske[level])), 1))

            # NOTE(review): MTG is evaluated three times on identical inputs.
            # The calls are kept separate so BatchNorm running statistics are
            # updated exactly as often as before.
            sal_n[level] = self.MTG(cur_sal, cur_edg, cur_ske) + cur_sal
            edg_n[level] = self.MTG(cur_sal, cur_edg, cur_ske) + cur_edg
            ske_n[level] = self.MTG(cur_sal, cur_edg, cur_ske) + cur_ske

        # Single-channel predictions; the three heads are shared across levels.
        sal_head, edg_head, ske_head = self.prediction
        sal_map, edg_map, ske_map = {}, {}, {}
        for level in levels:
            sal_map[level] = sal_head(sal_n[level])
            edg_map[level] = edg_head(edg_n[level])
            ske_map[level] = ske_head(ske_n[level])

        def to_input_size(pred):
            return F.interpolate(pred, size=size, mode='bilinear', align_corners=True)

        for level in levels:
            sal_map[level] = to_input_size(sal_map[level])
            edg_map[level] = to_input_size(edg_map[level])
            ske_map[level] = to_input_size(ske_map[level])

        # Assemble the outputs in the order callers expect.  Level 1
        # interleaves logits and probabilities differently from levels 2..5.
        outputs = [
            sal_n[1],
            sal_map[1], self.sigmoid(sal_map[1]),
            edg_map[1], self.sigmoid(edg_map[1]),
        ]
        for level in (2, 3, 4, 5):
            outputs.extend([
                sal_map[level], edg_map[level],
                self.sigmoid(sal_map[level]), self.sigmoid(edg_map[level]),
            ])
        for level in levels:
            outputs.extend([ske_map[level], self.sigmoid(ske_map[level])])

        return tuple(outputs)
692
+ # return x_sal1_n, sal_out, self.sigmoid(sal_out), edg_out, self.sigmoid(edg_out), sal2, edg2, self.sigmoid(
693
+ # sal2), self.sigmoid(edg2), sal3, edg3, self.sigmoid(sal3), self.sigmoid(edg3), sal4, edg4, self.sigmoid(
694
+ # sal4), self.sigmoid(edg4), sal5, edg5, self.sigmoid(sal5), self.sigmoid(edg5)
695
+
696
+
model/MobileNetV2.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+
4
# Locate a checkpoint downloader across library versions.  ImportError is
# caught (ModuleNotFoundError is a subclass of it) so the fallback also fires
# when a module exists but does not provide the requested name.
try:
    from torchvision.models.utils import load_state_dict_from_url  # torchvision 0.4+
except ImportError:
    try:
        from torch.hub import load_state_dict_from_url  # torch 1.x
    except ImportError:
        from torch.utils.model_zoo import load_url as load_state_dict_from_url  # torch 0.4.1

# Download locations of the pretrained weights, keyed by architecture name.
model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}
15
+
16
+
17
class ConvBNReLU(nn.Sequential):
    """Bias-free ``Conv2d`` -> ``BatchNorm2d`` -> ``ReLU6`` with "same" padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Square kernel size (int).
        stride: Convolution stride.
        groups: Convolution groups (``in_planes`` for a depthwise conv).
        dilation: Convolution dilation.

    The padding is ``dilation * (kernel_size - 1) // 2``, which keeps the
    spatial size at stride 1 for odd kernels.  The previous rule
    (``(kernel_size - 1) // 2``, or ``dilation`` when ``dilation != 1``) was
    only right for 3x3 kernels once dilation was used; the two rules agree
    whenever ``dilation == 1`` or ``kernel_size == 3``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1):
        padding = dilation * (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation,
                      bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )
28
+
29
+
30
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Layout: optional 1x1 expansion (skipped when ``expand_ratio == 1``),
    3x3 depthwise ``ConvBNReLU``, then a linear 1x1 projection with BatchNorm.
    A skip connection is added when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: Input channels.
        oup: Output channels.
        stride: Stride of the depthwise conv; must be 1 or 2.
        expand_ratio: Hidden width multiplier relative to ``inp``.
        dilation: Dilation of the depthwise conv.
    """

    def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        # pw (pointwise expansion), only when the block actually expands
        expansion = [ConvBNReLU(inp, hidden_dim, kernel_size=1)] if expand_ratio != 1 else []
        # dw (depthwise)
        depthwise = ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation)
        # pw-linear (projection without activation)
        projection = [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*expansion, depthwise, *projection)

    def forward(self, x):
        """Apply the block, adding the input back when the skip is enabled."""
        if not self.use_res_connect:
            return self.conv(x)
        return x + self.conv(x)
57
+
58
+
59
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor returning five intermediate feature maps.

    Args:
        pretrained: Accepted but not used by this class.
        num_classes: Accepted but not used; no classifier is built.
        width_mult: Multiplier applied to every layer's channel count.

    ``forward`` returns the outputs of ``features[1]``, ``[3]``, ``[6]``,
    ``[13]`` and ``[17]``, i.e. 16, 24, 32, 96 and 320 channels at
    ``width_mult == 1.0``.
    """

    def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0):
        super().__init__()
        block = InvertedResidual
        stem_channels = 32
        head_channels = 1280
        inverted_residual_setting = [
            # t, c, n, s, d
            [1, 16, 1, 1, 1],  # conv1 112*112*16
            [6, 24, 2, 2, 1],  # conv2 56*56*24
            [6, 32, 3, 2, 1],  # conv3 28*28*32
            [6, 64, 4, 2, 1],
            [6, 96, 3, 1, 1],  # conv4 14*14*96
            [6, 160, 3, 2, 1],
            [6, 320, 1, 1, 1],  # conv5 7*7*320
        ]

        # building first layer
        input_channel = int(stem_channels * width_mult)
        self.last_channel = int(head_channels * max(1.0, width_mult))
        layers = [ConvBNReLU(3, input_channel, stride=2)]

        # building inverted residual blocks
        for t, c, n, s, d in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for repeat in range(n):
                # Only the first block of a stage uses the stage stride.  The
                # stage dilation ``d`` is passed to every repeat.
                stride = s if repeat == 0 else 1
                layers.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d))
                input_channel = output_channel

        # building last several layers
        layers.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*layers)

        # weight initialization
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return a list with the five tapped feature maps, shallowest first."""
        tap_indices = [1, 3, 6, 13, 17]
        res = []
        for idx, layer in enumerate(self.features):
            x = layer(x)
            if idx in tap_indices:
                res.append(x)
        return res
113
+
114
+
115
def mobilenet_v2(pretrained=True, progress=True, **kwargs):
    """Build a ``MobileNetV2`` and optionally load ImageNet weights.

    Args:
        pretrained: When truthy, download the checkpoint listed in
            ``model_urls`` and load it.
        progress: Forwarded to the downloader.
        **kwargs: Forwarded to the ``MobileNetV2`` constructor.

    Returns:
        The constructed model.
    """
    model = MobileNetV2(**kwargs)
    if not pretrained:
        return model

    checkpoint_url = model_urls['mobilenet_v2']
    state_dict = load_state_dict_from_url(checkpoint_url, progress=progress)
    print("loading imagenet pretrained mobilenetv2")
    # strict=False: keys that do not match between checkpoint and model are skipped.
    model.load_state_dict(state_dict, strict=False)
    print("loaded imagenet pretrained mobilenetv2")
    return model