| | |
| |
|
| | |
| | |
| | |
| |
|
"""
Flexible UNet model which takes a Torchvision VGG or ResNet backbone as encoder.
Predicts multi-level features and makes sure that they are spatially aligned.
"""
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torchvision |
| |
|
| | from .base import BaseModel |
| | from .utils import checkpointed |
| |
|
| |
|
class DecoderBlock(nn.Module):
    """UNet decoder stage: 2x upsampling, skip concatenation, then conv refinement.

    Args:
        previous: number of channels of the coarser input feature map.
        skip: number of channels of the skip-connection feature map.
        out: number of channels produced by every convolution of the block.
        num_convs: number of (3x3 conv, optional norm, ReLU) units.
        norm: normalization layer factory called with the channel count, or
            None to disable normalization (the convolutions then carry a bias).
        padding: ``padding_mode`` passed to every convolution.
    """

    def __init__(
        self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()

        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        with_norm = norm is not None
        stack = []
        # Only the first convolution sees the concatenated (upsampled + skip) input.
        channels_in = previous + skip
        for _ in range(num_convs):
            stack.append(
                nn.Conv2d(
                    channels_in,
                    out,
                    kernel_size=3,
                    padding=1,
                    # A following normalization layer makes the bias redundant.
                    bias=not with_norm,
                    padding_mode=padding,
                )
            )
            if with_norm:
                stack.append(norm(out))
            stack.append(nn.ReLU(inplace=True))
            channels_in = out
        self.layers = nn.Sequential(*stack)

    def forward(self, previous, skip):
        """Upsample ``previous`` by 2x, crop ``skip`` to match, concatenate, refine.

        Both inputs must be 4D tensors (enforced by the shape unpacking below);
        they are concatenated along dim 1 (channels).
        """
        upsampled = self.upsample(previous)
        # If the skip map has an odd size, the downsampled-then-upsampled map is
        # one pixel smaller, so the skip map is cropped at the bottom/right.
        # Pooling with ceil_mode=True would make it larger instead (not handled).
        _, _, up_h, up_w = upsampled.shape
        _, _, skip_h, skip_w = skip.shape
        assert (up_h <= skip_h) and (up_w <= skip_w), "Using ceil_mode=True in pooling?"
        fused = torch.cat([upsampled, skip[:, :, :up_h, :up_w]], dim=1)
        return self.layers(fused)
| |
|
| |
|
class AdaptationBlock(nn.Sequential):
    """Project a feature map from ``inp`` to ``out`` channels with one biased 1x1 conv."""

    def __init__(self, inp, out):
        super().__init__(nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True))
| |
|
| |
|
class FeatureExtractor(BaseModel):
    """UNet-style feature extractor built on a Torchvision VGG or ResNet encoder.

    The encoder is split into ``num_downsample + 1`` blocks. Block 0 keeps the
    input resolution, and every following block starts with a downsampling
    layer (max pooling for VGG, the stem/residual stages for ResNet), so block
    ``k`` is expected to be at scale ``2**k`` (see ``self.scales``).

    An optional decoder of ``DecoderBlock`` modules walks from the coarsest
    level back to scale 0, using the encoder outputs as skip connections.
    Finally, one ``AdaptationBlock`` (1x1 conv) per entry of ``output_scales``
    projects the features at that scale to ``output_dim`` channels.
    """

    default_conf = {
        "pretrained": True,  # load "DEFAULT" weights; normalize inputs
        "input_dim": 3,  # input channels (must be 3 if pretrained)
        "output_scales": [0, 2, 4],  # encoder/decoder levels to output
        "output_dim": 128,  # int, or one value per output scale
        "encoder": "vgg16",  # torchvision.models name: vgg* or resnet*
        "num_downsample": 4,  # number of downsampling blocks to keep
        "decoder": [64, 64, 64, 64],  # channels, coarse to fine; or None
        "decoder_norm": "nn.BatchNorm2d",  # eval'ed; falsy disables norm
        "do_average_pooling": False,  # VGG only: max pool -> avg pool
        "checkpointed": False,  # forwarded to checkpointed(..., do=...)
        "padding": "zeros",  # padding_mode of encoder/decoder convs
    }
    # Per-channel normalization statistics (the standard torchvision ImageNet
    # values), applied in `_forward` only when `pretrained` is set.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf):
        """Instantiate the Torchvision backbone and split it into encoder blocks.

        Args:
            conf: model configuration with the keys of ``default_conf``.

        Returns:
            A tuple ``(encoder, skip_dims)``. ``encoder`` is an
            ``nn.ModuleList`` of ``num_downsample + 1`` blocks, and
            ``skip_dims[k]`` is the number of channels output by block ``k``.

        Raises:
            NotImplementedError: if ``conf.encoder`` is neither VGG nor ResNet.
        """
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            # Pretrained weights expect 3-channel inputs.
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
        # NOTE(review): `checkpointed` (from .utils) presumably enables gradient
        # checkpointing of the wrapped module when `do` is True — confirm.
        Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed)
        assert max(conf.output_scales) <= conf.num_downsample

        if conf.encoder.startswith("vgg"):
            # VGG has a linear structure: split it at every max-pooling layer so
            # that each block after the first starts with the pooling layer.
            # Strided convolutions and residual connections are not handled.
            skip_dims = []
            previous_dim = None
            blocks = [[]]
            for i, layer in enumerate(encoder.features):
                if isinstance(layer, torch.nn.Conv2d):
                    # Rebuild the first conv if the input channel count differs.
                    if i == 0 and conf.input_dim != layer.in_channels:
                        args = {k: getattr(layer, k) for k in layer.__constants__}
                        # Listed in __constants__ but not a Conv2d constructor arg.
                        args.pop("output_padding")
                        layer = torch.nn.Conv2d(
                            **{**args, "in_channels": conf.input_dim}
                        )
                    previous_dim = layer.out_channels
                elif isinstance(layer, torch.nn.MaxPool2d):
                    assert previous_dim is not None
                    # The block closed by this pooling layer outputs `previous_dim`.
                    skip_dims.append(previous_dim)
                    if (conf.num_downsample + 1) == len(blocks):
                        break
                    blocks.append([])
                    if conf.do_average_pooling:
                        assert layer.dilation == 1
                        layer = torch.nn.AvgPool2d(
                            kernel_size=layer.kernel_size,
                            stride=layer.stride,
                            padding=layer.padding,
                            ceil_mode=layer.ceil_mode,
                            count_include_pad=False,
                        )
                blocks[-1].append(layer)
            else:
                # The features ran out without reaching the `break`, so the last
                # block was never closed by a pooling layer and its output
                # channels were never recorded. Without this, `skip_dims` would
                # be one entry shorter than the encoder, e.g. when
                # `num_downsample` equals the number of pooling layers.
                skip_dims.append(previous_dim)
            encoder = [Block(*b) for b in blocks]
        elif conf.encoder.startswith("resnet"):
            # The residual stages are split manually.
            assert conf.encoder[len("resnet") :] in ["18", "34", "50", "101", "152"]
            assert conf.input_dim == 3, "Unsupported for now."
            block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
            block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1)
            block3 = encoder.layer2
            block4 = encoder.layer3
            block5 = encoder.layer4
            blocks = [block1, block2, block3, block4, block5]

            # Output channels of each block: the stem conv, then the last conv
            # of the final residual unit of each stage.
            skip_dims = [encoder.conv1.out_channels]
            for i in range(1, 5):
                modules = getattr(encoder, f"layer{i}")[-1]._modules
                conv = sorted(k for k in modules if k.startswith("conv"))[-1]
                skip_dims.append(modules[conv].out_channels)

            # Prepend an identity block so that block 0 does not downsample;
            # its output is the (3-channel, asserted above) input itself.
            encoder = [torch.nn.Identity()] + [Block(b) for b in blocks]
            skip_dims = [conf.input_dim] + skip_dims

            # Keep only the requested number of downsampling blocks.
            encoder = encoder[: conf.num_downsample + 1]
            skip_dims = skip_dims[: conf.num_downsample + 1]
        else:
            raise NotImplementedError(conf.encoder)

        assert (conf.num_downsample + 1) == len(encoder)
        encoder = nn.ModuleList(encoder)

        return encoder, skip_dims

    def _init(self, conf):
        """Build the encoder, the optional decoder, and the adaptation layers.

        NOTE(review): presumably invoked by ``BaseModel`` with the merged
        configuration — confirm against ``.base``.
        """
        # Encoder
        self.encoder, skip_dims = self.build_encoder(conf)
        self.skip_dims = skip_dims

        def update_padding(module):
            if isinstance(module, nn.Conv2d):
                module.padding_mode = conf.padding

        # The Torchvision convolutions are built with the default zero padding;
        # switch their padding mode afterwards if another one is requested.
        if conf.padding != "zeros":
            self.encoder.apply(update_padding)

        # Decoder: one block per encoder level below the coarsest, coarse to fine.
        if conf.decoder is not None:
            assert len(conf.decoder) == (len(skip_dims) - 1)
            Block = checkpointed(DecoderBlock, do=conf.checkpointed)
            # NOTE(review): `eval` executes arbitrary code taken from the config.
            # This is only acceptable because configs are trusted. Strings such
            # as "nn.BatchNorm2d" resolve against this module's globals.
            norm = eval(conf.decoder_norm) if conf.decoder_norm else None

            previous = skip_dims[-1]
            decoder = []
            for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]):
                decoder.append(
                    Block(previous, skip, out, norm=norm, padding=conf.padding)
                )
                previous = out
            self.decoder = nn.ModuleList(decoder)

        # Adaptation layers: one 1x1 projection per requested output scale.
        adaptation = []
        for idx, i in enumerate(conf.output_scales):
            if conf.decoder is None or i == (len(self.encoder) - 1):
                # Without a decoder, or at the coarsest level, the features are
                # the raw encoder outputs.
                input_ = skip_dims[i]
            else:
                # decoder[-1] outputs scale 0, decoder[-2] scale 1, and so on.
                input_ = conf.decoder[-1 - i]

            # output_dim is either shared by all scales or given per scale.
            dim = conf.output_dim
            if not isinstance(dim, int):
                dim = dim[idx]

            block = AdaptationBlock(input_, dim)
            adaptation.append(block)
        self.adaptation = nn.ModuleList(adaptation)
        self.scales = [2**s for s in conf.output_scales]

    def _forward(self, data):
        """Extract multi-scale feature maps from ``data["image"]``.

        Returns:
            A dict with ``"feature_maps"`` (one map per entry of
            ``output_scales``, in that order) and ``"skip_features"`` (the
            output of every encoder block).
        """
        image = data["image"]
        if self.conf.pretrained:
            # Normalize with the statistics the pretrained weights expect.
            # NOTE(review): assumes image is (..., 3, H, W) with values in
            # [0, 1] — confirm with the data pipeline.
            mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
            image = (image - mean[:, None, None]) / std[:, None, None]

        skip_features = []
        features = image
        for block in self.encoder:
            features = block(features)
            skip_features.append(features)

        # Same condition as the one used to build `self.decoder` in `_init`.
        if self.conf.decoder is not None:
            pre_features = [skip_features[-1]]
            for block, skip in zip(self.decoder, skip_features[:-1][::-1]):
                pre_features.append(block(pre_features[-1], skip))
            pre_features = pre_features[::-1]  # fine to coarse: index == scale
        else:
            pre_features = skip_features

        out_features = [
            adapt(pre_features[i])
            for adapt, i in zip(self.adaptation, self.conf.output_scales)
        ]
        pred = {"feature_maps": out_features, "skip_features": skip_features}
        return pred

    def loss(self, pred, data):
        """Not implemented: this model is used as a feature extractor only."""
        raise NotImplementedError

    def metrics(self, pred, data):
        """Not implemented: this model is used as a feature extractor only."""
        raise NotImplementedError
| |
|