# saliency-detection-demo / model / CyueNet_models.py
# kunkk's picture
# Upload 2 files
# 20b1b91 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from timm.models.layers import trunc_normal_
from einops import rearrange
import math
from model.MobileNetV2 import mobilenet_v2
from torch.nn import Parameter
class BasicConv2d(nn.Module):
    """Conv2d (no bias) followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size, stride, padding, dilation: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,  # the following BatchNorm supplies the shift
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
class Reduction(nn.Module):
    """Map ``in_channel`` features to ``out_channel`` and refine them.

    A 1x1 BasicConv2d changes the width, then two 3x3 BasicConv2d layers
    (padding 1) refine the result without changing the spatial size.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels of the returned feature map.
    """

    def __init__(self, in_channel, out_channel):
        super(Reduction, self).__init__()
        stages = [BasicConv2d(in_channel, out_channel, 1)]
        stages.extend(BasicConv2d(out_channel, out_channel, 3, padding=1) for _ in range(2))
        self.reduce = nn.Sequential(*stages)

    def forward(self, x):
        return self.reduce(x)
class TopDownLayer(nn.Module):
    """Refine a coarse map, resize it to a finer map and fuse the two.

    Args:
        channel: channel count of both inputs and of the output.
    """

    def __init__(self, channel):
        super(TopDownLayer, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(channel, channel, 3, 1, 1, bias=False), nn.BatchNorm2d(channel))
        self.relu = nn.ReLU()
        self.channel_compress = nn.Sequential(
            nn.Conv2d(channel * 2, channel, 1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU(),
        )

    def forward(self, x, x2):
        """Return the fusion of ``x`` (after conv/BN/ReLU and resize) with ``x2``.

        The result has ``x2``'s spatial size; the concatenated 2*channel map is
        squeezed back to ``channel`` by a 1x1 conv + BN + ReLU.
        """
        target_hw = x2.size()[2:]
        refined = self.relu(self.conv(x))
        refined = F.interpolate(refined, target_hw, mode='bilinear', align_corners=True)
        return self.channel_compress(torch.cat((refined, x2), dim=1))
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention across the flattened spatial positions.

    Despite the name, ``forward`` does not split the features into heads:
    ``head`` only feeds the divisibility check and ``self.d_k``. Query, key
    and value are each flattened to (B, positions, channels) and projected by
    the same ``inb`` linear layer. Attention is computed once, and the result
    is folded back to the query's spatial layout.

    Args:
        head: only used for ``d_model % head == 0`` and ``self.d_k``.
        d_model: width of the projected features and of the output channels.
        dropout: dropout probability applied to the attention weights.
        in_channels: channel count of the incoming feature maps (input width
            of ``inb``). Defaults to 32, the previously hard-coded value.
    """

    def __init__(self, head=8, d_model=32, dropout=0.1, in_channels=32):
        super(MultiHeadAttention, self).__init__()
        assert (d_model % head == 0)
        self.d_k = d_model // head
        self.head = head
        self.d_model = d_model
        # NOTE(review): these three projections are never applied in forward;
        # presumably kept so existing state_dicts still load.
        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.attn = None  # last attention matrix, set by forward
        self.inb = nn.Linear(in_channels, d_model)

    def self_attention(self, query, key, value, mask=None):
        """Return ``(softmax(q k^T / sqrt(d)) v, attention_weights)``.

        ``mask`` is accepted for API compatibility but is not applied.
        """
        d_k = query.shape[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        self_attn = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            self_attn = self.dropout(self_attn)
        return torch.matmul(self_attn, value), self_attn

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Each input is flattened from dim 2 onward and must have
        ``in_channels`` channels in dim 1. The output has ``d_model`` channels
        and the same trailing (spatial) shape as ``query``. The previous
        version assumed a square map (it used ``sqrt(H*W)``) and failed on
        non-square inputs.
        """
        n_batch = query.size(0)
        spatial_shape = query.shape[2:]  # restore this layout at the end
        query = self.inb(query.flatten(start_dim=2).permute(0, 2, 1))
        key = self.inb(key.flatten(start_dim=2).permute(0, 2, 1))
        value = self.inb(value.flatten(start_dim=2).permute(0, 2, 1))
        x, self.attn = self.self_attention(query, key, value, mask=mask)
        # (B, positions, d_model) -> (B, d_model, positions) -> (B, d_model, *spatial)
        x = x.permute(0, 2, 1)
        return x.reshape(n_batch, x.size(1), *spatial_shape)
class Upsample(nn.Module):
    """Parameter-free bilinear resize (align_corners=True) of ``x`` to ``x2``'s H, W."""

    def __init__(self):
        super(Upsample, self).__init__()

    def forward(self, x, x2):
        target_hw = x2.size()[2:]
        return F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
class MultiScaleAttention(nn.Module):
    """Spatially re-weight six feature maps, align them to ``x0`` and fuse them.

    Every input is multiplied by its spatial-attention map. Inputs x1..x5 are
    resized to ``x0``'s H, W. All six are concatenated and a 1x1 conv reduces
    the 6*channel result to ``channel``.

    Args:
        channel: channels of each input and of the output.
    """

    def __init__(self, channel):
        super(MultiScaleAttention, self).__init__()
        # Spatial attention modules.
        # NOTE(review): forward only uses attention_branches[0], shared by all six
        # inputs; branches 1-4 are never called. Changing that would alter the
        # behaviour of trained weights, so it is left as is.
        self.attention_branches = nn.ModuleList([SpatialAttention() for _ in range(5)])
        self.upsample = Upsample()
        self.conv_reduce = nn.Conv2d(channel * 6, channel, kernel_size=1)

    def forward(self, x0, x1, x2, x3, x4, x5):
        shared_attention = self.attention_branches[0]
        weighted = [shared_attention(feat) * feat for feat in (x0, x1, x2, x3, x4, x5)]
        # x0 defines the reference resolution; the other five are resized to it.
        aligned = [weighted[0]] + [self.upsample(feat, x0) for feat in weighted[1:]]
        return self.conv_reduce(torch.cat(aligned, dim=1))
class Basic2(nn.Module):
    """Six parallel receptive-field branches fused by MultiScaleAttention, plus a residual.

    Branch 0 is a 1x1 conv. Branch k (k = 1..5) is a (1, n) conv, then an
    (n, 1) conv, then a 3x3 conv with dilation n, for n = 3, 5, 7, 9, 11.
    Every branch preserves the spatial size. The fused map is projected back
    to ``in_channel`` and added to the input.

    Args:
        in_channel: channels of the input and of the output (residual).
        out_channel: channels produced by each branch.
    """

    def __init__(self, in_channel, out_channel):
        super(Basic2, self).__init__()
        self.relu = nn.ReLU(True)  # not used in forward
        # The ChannelAttention is immediately replaced by the SpatialAttention on
        # the next line, and neither is used in forward. Both constructions are
        # kept so the state_dict keys and init RNG draws stay as before.
        self.channel_attention = ChannelAttention(out_channel)
        self.channel_attention = SpatialAttention()
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = self._make_branch(in_channel, out_channel, 3)
        self.branch2 = self._make_branch(in_channel, out_channel, 5)
        self.branch3 = self._make_branch(in_channel, out_channel, 7)
        self.branch4 = self._make_branch(in_channel, out_channel, 9)
        self.branch5 = self._make_branch(in_channel, out_channel, 11)
        self.multi_scale_attention = MultiScaleAttention(out_channel)
        # The fused map has out_channel channels; the residual needs in_channel.
        # (Previously in_channel -> in_channel, which crashed when the two differ.)
        self.conv_combine = BasicConv2d(out_channel, in_channel, kernel_size=3, padding=1)

    @staticmethod
    def _make_branch(in_channel, out_channel, n):
        """Build a (1,n) -> (n,1) -> 3x3 dilation-n branch that keeps H and W."""
        return nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, n), padding=(0, n // 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(n, 1), padding=(n // 2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=n, dilation=n)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)
        x5 = self.branch5(x)
        x_att = self.multi_scale_attention(x0, x1, x2, x3, x4, x5)
        return self.conv_combine(x_att) + x
class ChannelAttention(nn.Module):
    """Per-channel gate in (0, 1) computed from global average and max pooling.

    Both pooled descriptors go through the same 1x1-conv MLP (reduction factor
    2, no bias). The two results are summed and passed through a sigmoid.

    Args:
        in_planes: number of channels of the input (and of the returned gate).

    Returns:
        A tensor of shape (B, in_planes, 1, 1) to be multiplied with the input.
    """

    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        hidden = in_planes // 2
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        from_avg = self.fc(self.avg_pool(x))
        from_max = self.fc(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)
class SpatialAttention(nn.Module):
    """Per-position gate in (0, 1) computed from channel-wise mean and max.

    Args:
        kernel_size: size of the conv over the 2-channel [mean, max] map;
            padding ``kernel_size // 2`` keeps H, W for odd sizes.

    Returns:
        A tensor with a single channel, to be multiplied with the input.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv1(stacked))
class MModule(nn.Module):
    """Basic2 block followed by residual channel then spatial gating, plus an outer residual.

    Args:
        channel: channels of the input and of the output.
    """

    def __init__(self, channel):
        super(MModule, self).__init__()
        self.basic = Basic2(channel, channel)
        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)

    def forward(self, x):
        feat = self.basic(x)
        # Gate-and-add: feat * gate + feat, first per channel, then per position.
        feat = feat * self.CA(feat) + feat
        feat = feat * self.SA(feat) + feat
        return feat + x
class MNodule(nn.Module):
    """Cascaded multi-receptive-field block with channel and spatial attention.

    There are three stages. Each one concatenates an asymmetric-conv branch
    with a dilated 3x3 conv (dilation 3, 5, 7), fuses them with a 3x3 conv,
    and adds dense residuals from earlier stages and the input. The last stage
    goes through a 1x1 conv, a channel gate, a spatial gate and two 3x3 convs.
    The result is summed with every stage output and the input. All layers
    preserve channel count and spatial size.

    Args:
        channel: channels of the input and of the output.
    """

    def __init__(self, channel):
        super(MNodule, self).__init__()
        self.atrconv1 = BasicConv2d(channel, channel, 3, padding=3, dilation=3)
        self.atrconv2 = BasicConv2d(channel, channel, 3, padding=5, dilation=5)
        self.atrconv3 = BasicConv2d(channel, channel, 3, padding=7, dilation=7)
        self.branch1 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(channel, channel, kernel_size=(3, 1), padding=(1, 0))
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(channel, channel, kernel_size=(5, 1), padding=(2, 0))
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(channel, channel, kernel_size=(7, 1), padding=(3, 0))
        )
        self.conv_cat1 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat2 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat3 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv1_1 = BasicConv2d(channel, channel, 1)
        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)
        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.sigmoid = nn.Sigmoid()  # not used in forward

    def forward(self, x):
        # Stage 1: (1x3, 3x1) branch || dilation-3 conv.
        x1 = self.branch1(x)
        x_atr1 = self.atrconv1(x)
        s_mfeb1 = self.conv_cat1(torch.cat((x1, x_atr1), 1)) + x
        # Stage 2: (1x5, 5x1) branch || dilation-5 conv, dense residuals.
        x2 = self.branch2(s_mfeb1)
        x_atr2 = self.atrconv2(s_mfeb1)
        s_mfeb2 = self.conv_cat2(torch.cat((x2, x_atr2), 1)) + s_mfeb1 + x
        # Stage 3: (1x7, 7x1) branch || dilation-7 conv, dense residuals.
        x3 = self.branch3(s_mfeb2)
        x_atr3 = self.atrconv3(s_mfeb2)
        s_mfeb3 = self.conv_cat3(torch.cat((x3, x_atr3), 1)) + s_mfeb1 + s_mfeb2 + x
        # Channel gate, then spatial gate, then refinement.
        x_m = self.conv1_1(s_mfeb3)
        x_ca = self.CA(x_m) * x_m
        x_mix = self.sal_conv(self.SA(x_ca) * x_ca) + s_mfeb1 + s_mfeb2 + s_mfeb3 + x
        return x_mix
class TransBasicConv2d(nn.Module):
    """ConvTranspose2d followed by BatchNorm2d and an in-place ReLU.

    With the defaults (kernel 2, stride 2, padding 0) the output H and W are
    twice the input's.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel_size, stride, padding, dilation, bias: forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False):
        super(TransBasicConv2d, self).__init__()
        self.Deconv = nn.ConvTranspose2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.Deconv(x)))
class features(nn.Module):
    """Five independent 1x1 BasicConv2d layers, one per pyramid level.

    Args:
        channel: channels of every input and output.
    """

    def __init__(self, channel):
        super(features, self).__init__()
        # Registers conv1 ... conv5 in order.
        for idx in range(1, 6):
            setattr(self, 'conv%d' % idx, BasicConv2d(channel, channel, 1))

    def forward(self, x1, x2, x3, x4, x5):
        """Apply conv_i to x_i and return the five results as a tuple."""
        convs = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        inputs = (x1, x2, x3, x4, x5)
        return tuple(conv(feat) for conv, feat in zip(convs, inputs))
class conv_upsamle(nn.Module):
    """Resize ``x`` to ``target``'s H, W (only if they differ), then apply a 3x3 BasicConv2d.

    Args:
        channel: channels of ``x`` and of the output.
    """

    def __init__(self, channel):
        super(conv_upsamle, self).__init__()
        self.conv = BasicConv2d(channel, channel, 3, padding=1)

    def forward(self, x, target):
        target_hw = target.size()[2:]
        if x.size()[2:] == target_hw:
            resized = x
        else:
            resized = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        return self.conv(resized)
class AP_MP(nn.Module):
    """Channel-wise L2 distance between an average-pooled and a max-pooled map.

    Args:
        stride: kernel size and stride of both pooling layers.
    """

    def __init__(self, stride=2):
        super(AP_MP, self).__init__()
        self.sz = stride
        self.gapLayer = nn.AvgPool2d(kernel_size=self.sz, stride=self.sz)
        self.gmpLayer = nn.MaxPool2d(kernel_size=self.sz, stride=self.sz)

    def forward(self, x1, x2):
        """Return ``||avgpool(x1) - maxpool(x2)||_2`` over dim 1, keeping dim 1 (size 1)."""
        diff = self.gapLayer(x1) - self.gmpLayer(x2)
        # The L2 norm squares every term, so taking abs() first is redundant.
        return torch.norm(diff, p=2, dim=1, keepdim=True)
class MOM(nn.Module):
    """Fuse two same-size feature maps with channel gating and cross spatial gating.

    Each input is refined by a 3x3 BasicConv2d and its own residual channel
    gate. Each stream is then gated by the spatial-attention map of the other
    stream. The concatenation is reduced by a 1x1 BasicConv2d and added to
    the first stream.

    Args:
        channel: channels of both inputs and of the output.
    """

    def __init__(self, channel):
        super(MOM, self).__init__()
        self.channel = channel
        self.conv1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv2 = BasicConv2d(channel, channel, 3, padding=1)
        self.CA1 = ChannelAttention(self.channel)
        self.CA2 = ChannelAttention(self.channel)
        self.SA1 = SpatialAttention()
        self.SA2 = SpatialAttention()
        # glbamp, upsample2 and upSA are not used in forward.
        self.glbamp = AP_MP()
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = BasicConv2d(channel * 2, channel, kernel_size=1, stride=1)
        self.upSA = SpatialAttention()

    def forward(self, x1, x2):
        """Return the fused map three times (as a 3-tuple of the same tensor)."""
        first = self.conv1(x1)
        second = self.conv2(x2)
        first = first + first * self.CA1(first)
        second = second + second * self.CA2(second)
        # Cross gating: each stream uses the *other* stream's spatial map.
        first_gated = first + first * self.SA2(second)
        second_gated = second + second * self.SA1(first)
        fused = self.conv(torch.cat([first_gated, second_gated], dim=1))
        fused = fused + first
        return fused, fused, fused
class AFM(nn.Module):
    """Gate ``x1`` by a global max-pooled, sigmoid-squashed descriptor of ``x2`` and refine it.

    Args:
        channel: channels of both inputs and of the output.
    """

    def __init__(self, channel):
        super(AFM, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()
        self.conv1_1 = nn.Conv2d(channel, channel, kernel_size=1)
        self.ca1 = ChannelAttention(channel)
        # NOTE(review): ca2 is not used in forward; presumably kept so existing
        # checkpoints containing its weights still load.
        self.ca2 = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )

    def forward(self, x1, x2):
        """Return the refined map three times (as a 3-tuple of the same tensor)."""
        # Spatial dims pooled to 1x1, squashed to (0, 1): one gate value per channel.
        gate = self.sigmoid(self.max_pool(x2))
        xb = gate * x1
        x = self.conv1_1(xb)
        x_c = self.ca1(x) * x
        # ``gate`` is 1x1 spatially and broadcasts in the residual sum.
        s_mea = self.sal_conv(self.sa(x_c) * x_c) + x1 + gate + xb
        return s_mea, s_mea, s_mea
class DummyMOM(nn.Module):
    """Minimal two-input fusion: concatenate both inputs and apply a 1x1 conv.

    It has the same call signature and 3-tuple return as MOM / AFM / YUEM.

    Args:
        channel: channels of each input and of the output. The conv was
            previously hard-coded to 64 -> 32, which equals the default
            ``channel=32``; it now follows ``channel``.
    """

    def __init__(self, channel):
        super(DummyMOM, self).__init__()
        self.conv1 = nn.Identity()  # keeps input == output (not used in forward)
        self.conv2 = nn.Identity()  # keeps input == output (not used in forward)
        # Two concatenated ``channel``-wide inputs -> ``channel`` output channels.
        self.conv = nn.Conv2d(2 * channel, channel, kernel_size=1)

    def forward(self, x1, x2):
        # Concatenate first, then reduce the channel count with the 1x1 conv.
        res = self.conv(torch.cat([x1, x2], dim=1))
        return res, res, res
class YUEM(nn.Module):
    """Fuse two feature maps: MModule on ``x1``, MNodule on ``x2``, then attention.

    Args:
        channel: channels of both inputs.
    """

    def __init__(self, channel):
        super(YUEM, self).__init__()
        self.channel = channel
        self.m1 = MModule(self.channel)
        self.m2 = MNodule(self.channel)
        # NOTE(review): ``channel`` is MultiHeadAttention's ``head`` argument here,
        # not ``d_model`` (which keeps its default). Confirm this is intended.
        self.mha = MultiHeadAttention(head=channel)

    def forward(self, x1, x2):
        """Attend with query ``m1(x1)``, key ``m2(x2)`` and value ``x2`` (unprocessed).

        Returns the attention output three times (as a 3-tuple of the same tensor).
        """
        query = self.m1(x1)
        key = self.m2(x2)
        fused = self.mha(query, key, x2)
        return fused, fused, fused
class MTG(nn.Module):
    """Fuse three same-size streams into one ``channel``-wide map.

    The three inputs are concatenated and passed through two 3x3
    BasicConv2d layers (``ccs[0]``).

    Args:
        channel: channels of each input and of the output.
    """

    def __init__(self, channel):
        super(MTG, self).__init__()
        # Five fusion blocks are built, but forward only uses ccs[0].
        blocks = []
        for _ in range(5):
            blocks.append(nn.Sequential(
                BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
                BasicConv2d(channel, channel, kernel_size=3, padding=1)
            ))
        self.ccs = nn.ModuleList(blocks)

    def forward(self, x_sal, x_edg, x_ske):
        stacked = torch.cat((x_sal, x_edg, x_ske), dim=1)
        return self.ccs[0](stacked)
class MMS(nn.Module):
    """Saliency network: MobileNetV2 encoder plus a three-stream top-down decoder.

    The decoder keeps three parallel feature streams named ``sal``, ``edg``
    and ``ske``. These are presumably saliency, edge and skeleton, but the
    names are the only evidence. The streams are refined from the coarsest
    backbone stage (level 5) to the finest (level 1).

    Args:
        pretrained: forwarded to ``mobilenet_v2``.
        channel: width that every backbone stage is reduced to; also the
            width of all decoder features.
    """
    def __init__(self, pretrained=True, channel=32):
        super(MMS, self).__init__()
        self.backbone = mobilenet_v2(pretrained)
        # Reduce the five backbone stages (16 / 24 / 32 / 96 / 320 input
        # channels) to ``channel`` channels each.
        self.Translayer1 = Reduction(16, channel)
        self.Translayer2 = Reduction(24, channel)
        self.Translayer3 = Reduction(32, channel)
        self.Translayer4 = Reduction(96, channel)
        self.Translayer5 = Reduction(320, channel)
        # Kernel-2 / stride-2 / padding-0 transposed convs: each doubles H and W.
        self.trans_conv1 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv2 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv3 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv4 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        # Per-level fusion modules. In forward, ``mom`` is shared by levels 1
        # and 2, ``yuem`` by levels 3 and 4, and ``afm`` handles level 5.
        # The commented lines are drop-in DummyMOM alternatives.
        self.mom = MOM(channel)
        # self.mom = DummyMOM(channel)
        self.afm = AFM(channel)
        # self.afm = DummyMOM(channel)
        self.yuem = YUEM(channel)
        # self.yuem = DummyMOM(channel)
        self.sigmoid = nn.Sigmoid()
        # One 1x1 BasicConv2d per level, per stream.
        self.sal_features = features(channel)
        self.edg_features = features(channel)
        self.ske_features = features(channel)
        # Cross-stream fusion used at levels 4..1 (one shared instance).
        self.MTG = MTG(channel)
        # Level-5 fusion for each stream; forward only uses index 0 of
        # ccs / cme / cms.
        self.ccs = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        self.cme = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        self.cms = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        # conv_cats[i] / cus[i], i = 0..11: one pair per (level, stream)
        # top-down merge. 4 levels x 3 streams, all used in forward.
        self.conv_cats = nn.ModuleList([nn.Sequential(
            BasicConv2d(2 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(12)])
        self.cus = nn.ModuleList([conv_upsamle(channel) for i in range(12)])
        # 1-channel prediction heads, shared across levels:
        # [0] sal stream, [1] edg stream, [2] ske stream.
        self.prediction = nn.ModuleList([
            nn.Sequential(
                BasicConv2d(channel, channel, kernel_size=3, padding=1),
                nn.Conv2d(channel, 1, kernel_size=1)
            ) for i in range(3)
        ])
        # NOTE(review): S1-S5 are not referenced in forward; presumably kept so
        # existing checkpoints still load.
        self.S1 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S2 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S3 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S4 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S5 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
    def forward(self, x):
        # Spatial size of the input; every prediction is resized back to it.
        size = x.size()[2:]
        # Five backbone feature maps, finest (conv1) to coarsest (conv5),
        # each reduced to ``channel`` channels.
        conv1, conv2, conv3, conv4, conv5 = self.backbone(x)
        conv1 = self.Translayer1(conv1)
        conv2 = self.Translayer2(conv2)
        conv3 = self.Translayer3(conv3)
        conv4 = self.Translayer4(conv4)
        conv5 = self.Translayer5(conv5)
        # Per-level fusion: each level is paired with the 2x-upsampled next
        # coarser level (level 5 uses conv5 for both inputs). Each module
        # returns (rgc, edg, ske) features for the three streams.
        rgc5, edg5, ske5 = self.afm(conv5, conv5)
        rgc4, edg4, ske4 = self.yuem(conv4, self.trans_conv4(conv5))
        rgc3, edg3, ske3 = self.yuem(conv3, self.trans_conv3(conv4))
        rgc2, edg2, ske2 = self.mom(conv2, self.trans_conv2(conv3))
        rgc1, edg1, ske1 = self.mom(conv1, self.trans_conv1(conv2))
        # Stream-specific 1x1 convs for every level.
        x_sal1, x_sal2, x_sal3, x_sal4, x_sal5 = self.sal_features(rgc1, rgc2, rgc3, rgc4, rgc5)
        x_edg1, x_edg2, x_edg3, x_edg4, x_edg5 = self.edg_features(edg1, edg2, edg3, edg4, edg5)
        x_ske1, x_ske2, x_ske3, x_ske4, x_ske5 = self.ske_features(ske1, ske2, ske3, ske4, ske5)
        # Level 5: fuse the three streams, with a residual per stream.
        # NOTE(review): the sal and edg lines concatenate x_sal5 twice and omit
        # x_ske5 (compare the ske line). This may be a typo, but trained weights
        # depend on it as written; confirm before changing.
        x_sal5_n = self.ccs[0](torch.cat((x_sal5, x_edg5, x_sal5), 1)) + x_sal5
        x_edg5_n = self.cme[0](torch.cat((x_sal5, x_edg5, x_sal5), 1)) + x_edg5
        x_ske5_n = self.cms[0](torch.cat((x_sal5, x_edg5, x_ske5), 1)) + x_ske5
        # Levels 4 -> 1, per stream:
        #   1. cus[i] resizes the previous level's result to this level if needed
        #      and applies a 3x3 BasicConv2d.
        #   2. conv_cats[i] fuses that result with this level's features.
        #   3. MTG fuses the three streams, with a residual per stream.
        # NOTE(review): MTG is called three times per level with identical
        # arguments, so it returns the same values each time. Collapsing to one
        # call would also change how often MTG's BatchNorm running statistics
        # are updated in train mode.
        x_sal4 = self.conv_cats[0](torch.cat((x_sal4, self.cus[0](x_sal5_n, x_sal4)), 1))
        x_edg4 = self.conv_cats[1](torch.cat((x_edg4, self.cus[1](x_edg5_n, x_edg4)), 1))
        x_ske4 = self.conv_cats[2](torch.cat((x_ske4, self.cus[2](x_ske5_n, x_ske4)), 1))
        x_sal4_n = self.MTG(x_sal4, x_edg4, x_ske4) + x_sal4
        x_edg4_n = self.MTG(x_sal4, x_edg4, x_ske4) + x_edg4
        x_ske4_n = self.MTG(x_sal4, x_edg4, x_ske4) + x_ske4
        x_sal3 = self.conv_cats[3](torch.cat((x_sal3, self.cus[3](x_sal4_n, x_sal3)), 1))
        x_edg3 = self.conv_cats[4](torch.cat((x_edg3, self.cus[4](x_edg4_n, x_edg3)), 1))
        x_ske3 = self.conv_cats[5](torch.cat((x_ske3, self.cus[5](x_ske4_n, x_ske3)), 1))
        x_sal3_n = self.MTG(x_sal3, x_edg3, x_ske3) + x_sal3
        x_edg3_n = self.MTG(x_sal3, x_edg3, x_ske3) + x_edg3
        x_ske3_n = self.MTG(x_sal3, x_edg3, x_ske3) + x_ske3
        x_sal2 = self.conv_cats[6](torch.cat((x_sal2, self.cus[6](x_sal3_n, x_sal2)), 1))
        x_edg2 = self.conv_cats[7](torch.cat((x_edg2, self.cus[7](x_edg3_n, x_edg2)), 1))
        x_ske2 = self.conv_cats[8](torch.cat((x_ske2, self.cus[8](x_ske3_n, x_ske2)), 1))
        x_sal2_n = self.MTG(x_sal2, x_edg2, x_ske2) + x_sal2
        x_edg2_n = self.MTG(x_sal2, x_edg2, x_ske2) + x_edg2
        x_ske2_n = self.MTG(x_sal2, x_edg2, x_ske2) + x_ske2
        x_sal1 = self.conv_cats[9](torch.cat((x_sal1, self.cus[9](x_sal2_n, x_sal1)), 1))
        x_edg1 = self.conv_cats[10](torch.cat((x_edg1, self.cus[10](x_edg2_n, x_edg1)), 1))
        x_ske1 = self.conv_cats[11](torch.cat((x_ske1, self.cus[11](x_ske2_n, x_ske1)), 1))
        x_sal1_n = self.MTG(x_sal1, x_edg1, x_ske1) + x_sal1
        x_edg1_n = self.MTG(x_sal1, x_edg1, x_ske1) + x_edg1
        x_ske1_n = self.MTG(x_sal1, x_edg1, x_ske1) + x_ske1
        # 1-channel logits per level and stream, from the shared heads.
        # From here on the x_*2_n .. x_*5_n names hold logits, not features.
        sal_out = self.prediction[0](x_sal1_n)
        edg_out = self.prediction[1](x_edg1_n)
        ske_out = self.prediction[2](x_ske1_n)
        x_sal2_n = self.prediction[0](x_sal2_n)
        x_edg2_n = self.prediction[1](x_edg2_n)
        x_ske2_n = self.prediction[2](x_ske2_n)
        x_sal3_n = self.prediction[0](x_sal3_n)
        x_edg3_n = self.prediction[1](x_edg3_n)
        x_ske3_n = self.prediction[2](x_ske3_n)
        x_sal4_n = self.prediction[0](x_sal4_n)
        x_edg4_n = self.prediction[1](x_edg4_n)
        x_ske4_n = self.prediction[2](x_ske4_n)
        x_sal5_n = self.prediction[0](x_sal5_n)
        x_edg5_n = self.prediction[1](x_edg5_n)
        x_ske5_n = self.prediction[2](x_ske5_n)
        # Resize every logit map to the input's spatial size.
        sal_out = F.interpolate(sal_out, size=size, mode='bilinear', align_corners=True)
        edg_out = F.interpolate(edg_out, size=size, mode='bilinear', align_corners=True)
        ske_out = F.interpolate(ske_out, size=size, mode='bilinear', align_corners=True)
        sal2 = F.interpolate(x_sal2_n, size=size, mode='bilinear', align_corners=True)
        edg2 = F.interpolate(x_edg2_n, size=size, mode='bilinear', align_corners=True)
        ske2 = F.interpolate(x_ske2_n, size=size, mode='bilinear', align_corners=True)
        sal3 = F.interpolate(x_sal3_n, size=size, mode='bilinear', align_corners=True)
        edg3 = F.interpolate(x_edg3_n, size=size, mode='bilinear', align_corners=True)
        ske3 = F.interpolate(x_ske3_n, size=size, mode='bilinear', align_corners=True)
        sal4 = F.interpolate(x_sal4_n, size=size, mode='bilinear', align_corners=True)
        edg4 = F.interpolate(x_edg4_n, size=size, mode='bilinear', align_corners=True)
        ske4 = F.interpolate(x_ske4_n, size=size, mode='bilinear', align_corners=True)
        sal5 = F.interpolate(x_sal5_n, size=size, mode='bilinear', align_corners=True)
        edg5 = F.interpolate(x_edg5_n, size=size, mode='bilinear', align_corners=True)
        ske5 = F.interpolate(x_ske5_n, size=size, mode='bilinear', align_corners=True)
        # Returns a 31-element tuple:
        #   [0]      x_sal1_n: level-1 sal *features* (``channel`` channels, not logits)
        #   [1:5]    sal_out, sigmoid(sal_out), edg_out, sigmoid(edg_out)    (level 1)
        #   [5:21]   for levels 2..5: sal, edg, sigmoid(sal), sigmoid(edg)
        #   [21:31]  ske_out, sigmoid(ske_out), then ske2..ske5 each followed by its sigmoid
        # All maps except [0] are 1-channel and at the input's spatial size.
        return x_sal1_n, sal_out, self.sigmoid(sal_out), edg_out, self.sigmoid(edg_out), sal2, edg2, self.sigmoid(
            sal2), self.sigmoid(edg2), sal3, edg3, self.sigmoid(sal3), self.sigmoid(edg3), sal4, edg4, self.sigmoid(
            sal4), self.sigmoid(edg4), sal5, edg5, self.sigmoid(sal5), self.sigmoid(edg5), ske_out, self.sigmoid(
            ske_out), ske2, self.sigmoid(ske2), ske3, self.sigmoid(ske3), ske4, self.sigmoid(ske4), ske5, self.sigmoid(
            ske5)
        # Earlier return without the ske outputs, kept for reference:
        # return x_sal1_n, sal_out, self.sigmoid(sal_out), edg_out, self.sigmoid(edg_out), sal2, edg2, self.sigmoid(
        # sal2), self.sigmoid(edg2), sal3, edg3, self.sigmoid(sal3), self.sigmoid(edg3), sal4, edg4, self.sigmoid(
        # sal4), self.sigmoid(edg4), sal5, edg5, self.sigmoid(sal5), self.sigmoid(edg5)