import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

from model.MobileNetV2 import mobilenet_v2

class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class Reduction(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Reduction, self).__init__()
        self.reduce = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, 3, padding=1),
            BasicConv2d(out_channel, out_channel, 3, padding=1)
        )

    def forward(self, x):
        return self.reduce(x)

class TopDownLayer(nn.Module):
    def __init__(self, channel):
        super(TopDownLayer, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(channel, channel, 3, 1, 1, bias=False), nn.BatchNorm2d(channel))
        self.relu = nn.ReLU()
        self.channel_compress = nn.Sequential(
            nn.Conv2d(channel * 2, channel, 1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU()
        )

    def forward(self, x, x2):
        res1 = self.conv(x)
        res1 = self.relu(res1)
        res1 = F.interpolate(res1, x2.size()[2:], mode='bilinear', align_corners=True)
        res_cat = torch.cat((res1, x2), dim=1)
        resl = self.channel_compress(res_cat)
        return resl

class MultiHeadAttention(nn.Module):
    def __init__(self, head=8, d_model=32, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert (d_model % head == 0)
        self.d_k = d_model // head
        self.head = head
        self.d_model = d_model
        # Note: the q/k/v projections below are registered but forward() only uses self.inb.
        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.attn = None
        self.inb = nn.Linear(32, d_model)

    def self_attention(self, query, key, value, mask=None):
        d_k = query.shape[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        self_attn = F.softmax(scores, dim=-1)
        # self.attn = self_attn if self.attn is None else self.attn + self_attn
        if self.dropout is not None:
            self_attn = self.dropout(self_attn)
        return torch.matmul(self_attn, value), self_attn

    def forward(self, query, key, value, mask=None):
        n_batch = query.size(0)
        # Flatten each (B, C, H, W) map into a (B, H*W, C) token sequence before attention.
        query = query.flatten(start_dim=2).permute(0, 2, 1)
        query = self.inb(query)
        key = key.flatten(start_dim=2).permute(0, 2, 1)
        key = self.inb(key)
        value = value.flatten(start_dim=2).permute(0, 2, 1)
        value = self.inb(value)
        x, self.attn = self.self_attention(query, key, value, mask=mask)
        # Fold the token dimension back into a square feature map (assumes H == W).
        x = x.permute(0, 2, 1)
        embedding_dim = x.size(-1)
        d_k = h = int(embedding_dim ** 0.5)
        x = einops.rearrange(x, 'b n (d_k h) -> b n d_k h', d_k=d_k, h=h)
        return x

class Upsample(nn.Module):
    def __init__(self):
        super(Upsample, self).__init__()

    def forward(self, x, x2):
        x = F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True)
        return x

class MultiScaleAttention(nn.Module):
    def __init__(self, channel):
        super(MultiScaleAttention, self).__init__()
        # One spatial attention module per branch (six branches are fused below).
        self.attention_branches = nn.ModuleList([SpatialAttention() for _ in range(6)])
        self.upsample = Upsample()
        self.conv_reduce = nn.Conv2d(channel * 6, channel, kernel_size=1)

    def forward(self, x0, x1, x2, x3, x4, x5):
        x0_att = self.attention_branches[0](x0) * x0
        x1_att = self.attention_branches[1](x1) * x1
        x2_att = self.attention_branches[2](x2) * x2
        x3_att = self.attention_branches[3](x3) * x3
        x4_att = self.attention_branches[4](x4) * x4
        x5_att = self.attention_branches[5](x5) * x5
        x1_att_up = self.upsample(x1_att, x0)
        x2_att_up = self.upsample(x2_att, x0)
        x3_att_up = self.upsample(x3_att, x0)
        x4_att_up = self.upsample(x4_att, x0)
        x5_att_up = self.upsample(x5_att, x0)
        x_cat = torch.cat((x0_att, x1_att_up, x2_att_up, x3_att_up, x4_att_up, x5_att_up), dim=1)
        x_out = self.conv_reduce(x_cat)
        return x_out

class Basic2(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Basic2, self).__init__()
        self.relu = nn.ReLU(True)
        # join
        self.channel_attention = ChannelAttention(out_channel)
        self.spatial_attention = SpatialAttention()
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.branch4 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 9), padding=(0, 4)),
            BasicConv2d(out_channel, out_channel, kernel_size=(9, 1), padding=(4, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=9, dilation=9)
        )
        self.branch5 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 11), padding=(0, 5)),
            BasicConv2d(out_channel, out_channel, kernel_size=(11, 1), padding=(5, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=11, dilation=11)
        )
        self.multi_scale_attention = MultiScaleAttention(out_channel)
        self.conv_combine = BasicConv2d(in_channel, in_channel, kernel_size=3, padding=1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)
        x5 = self.branch5(x)
        x_att = self.multi_scale_attention(x0, x1, x2, x3, x4, x5)
        x_combined = self.conv_combine(x_att)
        x = x_combined + x
        return x

class ChannelAttention(nn.Module):
    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 2, 1, bias=False),
                                nn.ReLU(),
                                nn.Conv2d(in_planes // 2, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x1 = torch.cat([avg_out, max_out], dim=1)
        x2 = self.conv1(x1)
        return self.sigmoid(x2)

class MModule(nn.Module):
    def __init__(self, channel):
        super(MModule, self).__init__()
        self.basic = Basic2(channel, channel)
        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)

    def forward(self, x):
        x_mix = self.basic(x)
        x_mix = x_mix * self.CA(x_mix) + x_mix
        x_mix1 = x_mix * self.SA(x_mix) + x_mix
        x_mix1 = x_mix1 + x
        return x_mix1

class MNodule(nn.Module):
    def __init__(self, channel):
        super(MNodule, self).__init__()
        self.atrconv1 = BasicConv2d(channel, channel, 3, padding=3, dilation=3)
        self.atrconv2 = BasicConv2d(channel, channel, 3, padding=5, dilation=5)
        self.atrconv3 = BasicConv2d(channel, channel, 3, padding=7, dilation=7)
        self.branch1 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(channel, channel, kernel_size=(3, 1), padding=(1, 0))
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(channel, channel, kernel_size=(5, 1), padding=(2, 0))
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(channel, channel, kernel_size=(7, 1), padding=(3, 0))
        )
        self.conv_cat1 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat2 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat3 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv1_1 = BasicConv2d(channel, channel, 1)
        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)
        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.branch1(x)
        x_atr1 = self.atrconv1(x)
        s_mfeb1 = self.conv_cat1(torch.cat((x1, x_atr1), 1)) + x
        x2 = self.branch2(s_mfeb1)
        x_atr2 = self.atrconv2(s_mfeb1)
        s_mfeb2 = self.conv_cat2(torch.cat((x2, x_atr2), 1)) + s_mfeb1 + x
        x3 = self.branch3(s_mfeb2)
        x_atr3 = self.atrconv3(s_mfeb2)
        s_mfeb3 = self.conv_cat3(torch.cat((x3, x_atr3), 1)) + s_mfeb1 + s_mfeb2 + x
        x_m = self.conv1_1(s_mfeb3)
        x_ca = self.CA(x_m) * x_m
        x_mix = self.sal_conv((self.SA(x_ca)) * x_ca) + s_mfeb1 + s_mfeb2 + s_mfeb3 + x
        return x_mix

class TransBasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False):
        super(TransBasicConv2d, self).__init__()
        self.Deconv = nn.ConvTranspose2d(in_planes, out_planes,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.Deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class features(nn.Module):
    def __init__(self, channel):
        super(features, self).__init__()
        self.conv1 = BasicConv2d(channel, channel, 1)
        self.conv2 = BasicConv2d(channel, channel, 1)
        self.conv3 = BasicConv2d(channel, channel, 1)
        self.conv4 = BasicConv2d(channel, channel, 1)
        self.conv5 = BasicConv2d(channel, channel, 1)

    def forward(self, x1, x2, x3, x4, x5):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x3 = self.conv3(x3)
        x4 = self.conv4(x4)
        x5 = self.conv5(x5)
        return x1, x2, x3, x4, x5

class conv_upsamle(nn.Module):
    def __init__(self, channel):
        super(conv_upsamle, self).__init__()
        self.conv = BasicConv2d(channel, channel, 3, padding=1)

    def forward(self, x, target):
        if x.size()[2:] != target.size()[2:]:
            x = F.interpolate(x, size=target.size()[2:], mode='bilinear', align_corners=True)
        x = self.conv(x)
        return x

class AP_MP(nn.Module):
    def __init__(self, stride=2):
        super(AP_MP, self).__init__()
        self.sz = stride
        self.gapLayer = nn.AvgPool2d(kernel_size=self.sz, stride=self.sz)
        self.gmpLayer = nn.MaxPool2d(kernel_size=self.sz, stride=self.sz)

    def forward(self, x1, x2):
        apimg = self.gapLayer(x1)
        mpimg = self.gmpLayer(x2)
        byimg = torch.norm(abs(apimg - mpimg), p=2, dim=1, keepdim=True)
        return byimg

class MOM(nn.Module):
    def __init__(self, channel):
        super(MOM, self).__init__()
        self.channel = channel
        self.conv1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv2 = BasicConv2d(channel, channel, 3, padding=1)
        self.CA1 = ChannelAttention(self.channel)
        self.CA2 = ChannelAttention(self.channel)
        self.SA1 = SpatialAttention()
        self.SA2 = SpatialAttention()
        self.glbamp = AP_MP()
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = BasicConv2d(channel * 2, channel, kernel_size=1, stride=1)
        self.upSA = SpatialAttention()

    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x1 = x1 + x1 * self.CA1(x1)
        x2 = x2 + x2 * self.CA2(x2)
        nx1 = x1 + x1 * self.SA2(x2)
        nx2 = x2 + x2 * self.SA1(x1)
        res = self.conv(torch.cat([nx1, nx2], dim=1))
        res = res + x1
        edg = res
        ske = res
        return res, edg, ske

class AFM(nn.Module):
    def __init__(self, channel):
        super(AFM, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()
        self.conv1_1 = nn.Conv2d(channel, channel, kernel_size=1)
        self.ca1 = ChannelAttention(channel)
        self.ca2 = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )

    def forward(self, x1, x2):
        x2 = self.sigmoid(self.max_pool(x2))
        xb = x2 * x1
        x = self.conv1_1(xb)
        x_c = self.ca1(x) * x
        x_d = self.ca2(x) * x
        s_mea = self.sal_conv((self.sa(x_c)) * x_c) + x1 + x2 + xb
        ske = s_mea
        e_pred = s_mea
        return s_mea, e_pred, ske

class DummyMOM(nn.Module):
    def __init__(self, channel):
        super(DummyMOM, self).__init__()
        self.conv1 = nn.Identity()  # pass the input through unchanged
        self.conv2 = nn.Identity()  # pass the input through unchanged
        # expects 64 input channels after concatenation
        self.conv = nn.Conv2d(64, 32, kernel_size=1)  # 1x1 convolution to adjust the channel count

    def forward(self, x1, x2):
        # concatenate first, then reduce the channel count to 32
        res = self.conv(torch.cat([x1, x2], dim=1))
        edg = res
        ske = res
        return res, edg, ske

class YUEM(nn.Module):
    def __init__(self, channel):
        super(YUEM, self).__init__()
        self.channel = channel
        self.m1 = MModule(self.channel)
        self.m2 = MNodule(self.channel)
        self.mha = MultiHeadAttention(d_model=channel)

    def forward(self, x1, x2):
        x1 = self.m1(x1)
        x21 = self.m2(x2)
        res = self.mha(x1, x21, x2)
        edg = res
        ske = res
        return res, edg, ske

class MTG(nn.Module):
    def __init__(self, channel):
        super(MTG, self).__init__()
        self.ccs = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])

    def forward(self, x_sal, x_edg, x_ske):
        x_combined = torch.cat((x_sal, x_edg, x_ske), dim=1)
        x_sal_n = self.ccs[0](x_combined)
        return x_sal_n

class MMS(nn.Module):
    def __init__(self, pretrained=True, channel=32):
        super(MMS, self).__init__()
        self.backbone = mobilenet_v2(pretrained)
        self.Translayer1 = Reduction(16, channel)
        self.Translayer2 = Reduction(24, channel)
        self.Translayer3 = Reduction(32, channel)
        self.Translayer4 = Reduction(96, channel)
        self.Translayer5 = Reduction(320, channel)
        self.trans_conv1 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv2 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv3 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv4 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.mom = MOM(channel)
        # self.mom = DummyMOM(channel)
        self.afm = AFM(channel)
        # self.afm = DummyMOM(channel)
        self.yuem = YUEM(channel)
        # self.yuem = DummyMOM(channel)
        self.sigmoid = nn.Sigmoid()
        self.sal_features = features(channel)
        self.edg_features = features(channel)
        self.ske_features = features(channel)
        self.MTG = MTG(channel)
        self.ccs = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        self.cme = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        self.cms = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        self.conv_cats = nn.ModuleList([nn.Sequential(
            BasicConv2d(2 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(12)])
        self.cus = nn.ModuleList([conv_upsamle(channel) for i in range(12)])
        self.prediction = nn.ModuleList([
            nn.Sequential(
                BasicConv2d(channel, channel, kernel_size=3, padding=1),
                nn.Conv2d(channel, 1, kernel_size=1)
            ) for i in range(3)
        ])
        self.S1 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S2 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S3 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S4 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S5 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
    def forward(self, x):
        size = x.size()[2:]
        conv1, conv2, conv3, conv4, conv5 = self.backbone(x)
        conv1 = self.Translayer1(conv1)
        conv2 = self.Translayer2(conv2)
        conv3 = self.Translayer3(conv3)
        conv4 = self.Translayer4(conv4)
        conv5 = self.Translayer5(conv5)
        rgc5, edg5, ske5 = self.afm(conv5, conv5)
        rgc4, edg4, ske4 = self.yuem(conv4, self.trans_conv4(conv5))
        rgc3, edg3, ske3 = self.yuem(conv3, self.trans_conv3(conv4))
        rgc2, edg2, ske2 = self.mom(conv2, self.trans_conv2(conv3))
        rgc1, edg1, ske1 = self.mom(conv1, self.trans_conv1(conv2))
        x_sal1, x_sal2, x_sal3, x_sal4, x_sal5 = self.sal_features(rgc1, rgc2, rgc3, rgc4, rgc5)
        x_edg1, x_edg2, x_edg3, x_edg4, x_edg5 = self.edg_features(edg1, edg2, edg3, edg4, edg5)
        x_ske1, x_ske2, x_ske3, x_ske4, x_ske5 = self.ske_features(ske1, ske2, ske3, ske4, ske5)
        # Level 5: fuse the saliency, edge and skeleton features of the deepest stage.
        x_sal5_n = self.ccs[0](torch.cat((x_sal5, x_edg5, x_ske5), 1)) + x_sal5
        x_edg5_n = self.cme[0](torch.cat((x_sal5, x_edg5, x_ske5), 1)) + x_edg5
        x_ske5_n = self.cms[0](torch.cat((x_sal5, x_edg5, x_ske5), 1)) + x_ske5
        # Top-down decoding: merge the upsampled deeper features at each level, then refine with MTG.
        x_sal4 = self.conv_cats[0](torch.cat((x_sal4, self.cus[0](x_sal5_n, x_sal4)), 1))
        x_edg4 = self.conv_cats[1](torch.cat((x_edg4, self.cus[1](x_edg5_n, x_edg4)), 1))
        x_ske4 = self.conv_cats[2](torch.cat((x_ske4, self.cus[2](x_ske5_n, x_ske4)), 1))
        x_sal4_n = self.MTG(x_sal4, x_edg4, x_ske4) + x_sal4
        x_edg4_n = self.MTG(x_sal4, x_edg4, x_ske4) + x_edg4
        x_ske4_n = self.MTG(x_sal4, x_edg4, x_ske4) + x_ske4
        x_sal3 = self.conv_cats[3](torch.cat((x_sal3, self.cus[3](x_sal4_n, x_sal3)), 1))
        x_edg3 = self.conv_cats[4](torch.cat((x_edg3, self.cus[4](x_edg4_n, x_edg3)), 1))
        x_ske3 = self.conv_cats[5](torch.cat((x_ske3, self.cus[5](x_ske4_n, x_ske3)), 1))
        x_sal3_n = self.MTG(x_sal3, x_edg3, x_ske3) + x_sal3
        x_edg3_n = self.MTG(x_sal3, x_edg3, x_ske3) + x_edg3
        x_ske3_n = self.MTG(x_sal3, x_edg3, x_ske3) + x_ske3
        x_sal2 = self.conv_cats[6](torch.cat((x_sal2, self.cus[6](x_sal3_n, x_sal2)), 1))
        x_edg2 = self.conv_cats[7](torch.cat((x_edg2, self.cus[7](x_edg3_n, x_edg2)), 1))
        x_ske2 = self.conv_cats[8](torch.cat((x_ske2, self.cus[8](x_ske3_n, x_ske2)), 1))
        x_sal2_n = self.MTG(x_sal2, x_edg2, x_ske2) + x_sal2
        x_edg2_n = self.MTG(x_sal2, x_edg2, x_ske2) + x_edg2
        x_ske2_n = self.MTG(x_sal2, x_edg2, x_ske2) + x_ske2
        x_sal1 = self.conv_cats[9](torch.cat((x_sal1, self.cus[9](x_sal2_n, x_sal1)), 1))
        x_edg1 = self.conv_cats[10](torch.cat((x_edg1, self.cus[10](x_edg2_n, x_edg1)), 1))
        x_ske1 = self.conv_cats[11](torch.cat((x_ske1, self.cus[11](x_ske2_n, x_ske1)), 1))
        x_sal1_n = self.MTG(x_sal1, x_edg1, x_ske1) + x_sal1
        x_edg1_n = self.MTG(x_sal1, x_edg1, x_ske1) + x_edg1
        x_ske1_n = self.MTG(x_sal1, x_edg1, x_ske1) + x_ske1
        # Shared prediction heads produce one-channel maps for every level.
        sal_out = self.prediction[0](x_sal1_n)
        edg_out = self.prediction[1](x_edg1_n)
        ske_out = self.prediction[2](x_ske1_n)
        x_sal2_n = self.prediction[0](x_sal2_n)
        x_edg2_n = self.prediction[1](x_edg2_n)
        x_ske2_n = self.prediction[2](x_ske2_n)
        x_sal3_n = self.prediction[0](x_sal3_n)
        x_edg3_n = self.prediction[1](x_edg3_n)
        x_ske3_n = self.prediction[2](x_ske3_n)
        x_sal4_n = self.prediction[0](x_sal4_n)
        x_edg4_n = self.prediction[1](x_edg4_n)
        x_ske4_n = self.prediction[2](x_ske4_n)
        x_sal5_n = self.prediction[0](x_sal5_n)
        x_edg5_n = self.prediction[1](x_edg5_n)
        x_ske5_n = self.prediction[2](x_ske5_n)
        # Upsample every prediction to the input resolution.
        sal_out = F.interpolate(sal_out, size=size, mode='bilinear', align_corners=True)
        edg_out = F.interpolate(edg_out, size=size, mode='bilinear', align_corners=True)
        ske_out = F.interpolate(ske_out, size=size, mode='bilinear', align_corners=True)
        sal2 = F.interpolate(x_sal2_n, size=size, mode='bilinear', align_corners=True)
        edg2 = F.interpolate(x_edg2_n, size=size, mode='bilinear', align_corners=True)
        ske2 = F.interpolate(x_ske2_n, size=size, mode='bilinear', align_corners=True)
        sal3 = F.interpolate(x_sal3_n, size=size, mode='bilinear', align_corners=True)
        edg3 = F.interpolate(x_edg3_n, size=size, mode='bilinear', align_corners=True)
        ske3 = F.interpolate(x_ske3_n, size=size, mode='bilinear', align_corners=True)
        sal4 = F.interpolate(x_sal4_n, size=size, mode='bilinear', align_corners=True)
        edg4 = F.interpolate(x_edg4_n, size=size, mode='bilinear', align_corners=True)
        ske4 = F.interpolate(x_ske4_n, size=size, mode='bilinear', align_corners=True)
        sal5 = F.interpolate(x_sal5_n, size=size, mode='bilinear', align_corners=True)
        edg5 = F.interpolate(x_edg5_n, size=size, mode='bilinear', align_corners=True)
        ske5 = F.interpolate(x_ske5_n, size=size, mode='bilinear', align_corners=True)
        return (x_sal1_n,
                sal_out, self.sigmoid(sal_out), edg_out, self.sigmoid(edg_out),
                sal2, edg2, self.sigmoid(sal2), self.sigmoid(edg2),
                sal3, edg3, self.sigmoid(sal3), self.sigmoid(edg3),
                sal4, edg4, self.sigmoid(sal4), self.sigmoid(edg4),
                sal5, edg5, self.sigmoid(sal5), self.sigmoid(edg5),
                ske_out, self.sigmoid(ske_out),
                ske2, self.sigmoid(ske2), ske3, self.sigmoid(ske3),
                ske4, self.sigmoid(ske4), ske5, self.sigmoid(ske5))
        # return x_sal1_n, sal_out, self.sigmoid(sal_out), edg_out, self.sigmoid(edg_out), sal2, edg2, self.sigmoid(
        #     sal2), self.sigmoid(edg2), sal3, edg3, self.sigmoid(sal3), self.sigmoid(edg3), sal4, edg4, self.sigmoid(
        #     sal4), self.sigmoid(edg4), sal5, edg5, self.sigmoid(sal5), self.sigmoid(edg5)
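

# A minimal smoke-test sketch. It assumes the repo's model/MobileNetV2.py exposes a
# mobilenet_v2() that returns five feature maps with 16/24/32/96/320 channels (as the
# Reduction trans-layers above expect), accepts a boolean `pretrained` flag, and that the
# input is square so MultiHeadAttention can fold its tokens back into an H x W map.
if __name__ == '__main__':
    net = MMS(pretrained=False, channel=32).eval()
    dummy = torch.randn(1, 3, 256, 256)  # hypothetical input resolution for the check
    with torch.no_grad():
        outputs = net(dummy)
    # outputs[1] is the full-resolution saliency logit map, outputs[2] its sigmoid.
    print(len(outputs), outputs[1].shape, outputs[2].shape)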