Spaces:
Sleeping
Sleeping
| # Author: Qiuyi | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.utils.parametrizations import spectral_norm | |
class _Combo(nn.Module):
    """Base class for small layer stacks.

    Subclasses build their layers in ``__init__`` and store the result as
    ``self.model``; running the combo just runs that stored module.
    """

    def forward(self, input):
        stack = self.model
        return stack(input)
class LinearCombo(_Combo):
    """A ``nn.Linear`` layer followed by an activation module.

    Args:
        in_features: Number of input features of the linear layer.
        out_features: Number of output features of the linear layer.
        activation: Module applied after the linear layer. ``None`` (the
            default) creates a fresh ``nn.LeakyReLU(0.2)`` for this instance.
    """

    def __init__(self, in_features, out_features, activation=None):
        super().__init__()
        # Build the default here rather than in the signature: a module used
        # as a default value is created once, at class definition, and would
        # otherwise be the very same object inside every LinearCombo.
        if activation is None:
            activation = nn.LeakyReLU(0.2)
        self.model = nn.Sequential(
            nn.Linear(in_features, out_features),
            activation,
        )
class SNLinearCombo(_Combo):
    """A spectrally normalized ``nn.Linear`` layer followed by an activation.

    The linear layer is wrapped with
    ``torch.nn.utils.parametrizations.spectral_norm``.

    Args:
        in_features: Number of input features of the linear layer.
        out_features: Number of output features of the linear layer.
        activation: Module applied after the linear layer. ``None`` (the
            default) creates a fresh ``nn.LeakyReLU(0.2)`` for this instance.
    """

    def __init__(self, in_features, out_features, activation=None):
        super().__init__()
        # Build the default per instance: a module default in the signature
        # would be a single object shared by every SNLinearCombo.
        if activation is None:
            activation = nn.LeakyReLU(0.2)
        self.model = nn.Sequential(
            spectral_norm(nn.Linear(in_features, out_features)),
            activation,
        )
class MLP(nn.Module):
    """Regular fully connected network generating features.

    Hidden layers are built by calling ``combo(in_ftr, out_ftr)`` once per
    entry of ``layer_width``; the output layer is a plain ``nn.Linear`` with
    no activation.

    Args:
        in_features: The number of input features.
        out_features: The number of output features.
        layer_width: The widths of the hidden layers (copied into a list).
        combo: The layer combination to be stacked up; called as
            ``combo(in_features, out_features)`` for each hidden layer.

    Shape:
        - Input: `(N, H_in)` where H_in = in_features.
        - Output: `(N, H_out)` where H_out = out_features.
    """

    def __init__(
        self, in_features: int, out_features: int, layer_width: list,
        combo=LinearCombo,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Copy so later mutation of the caller's sequence cannot desync the
        # stored widths from the layers built below.
        self.layer_width = list(layer_width)
        self.model = self._build_model(combo)

    def forward(self, input):
        return self.model(input)

    def _build_model(self, combo):
        """Stack one ``combo`` per hidden layer, then a plain output Linear."""
        sizes = self.layer_sizes
        model = nn.Sequential()
        # Hidden layers: every (in, out) pair except the final one.
        for idx, (in_ftr, out_ftr) in enumerate(sizes[:-1]):
            model.add_module(str(idx), combo(in_ftr, out_ftr))
        # The output layer is named right after the hidden layers ("0" when
        # layer_width is empty), matching the sequential naming above.
        model.add_module(str(len(sizes) - 1), nn.Linear(*sizes[-1]))
        return model

    @property
    def layer_sizes(self):
        """``(in, out)`` feature pairs, one per Linear layer, input to output.

        Exposed as a property because callers (``_build_model`` here and in
        subclasses) index it directly, e.g. ``self.layer_sizes[-1]``.
        """
        return list(zip([self.in_features] + self.layer_width,
                        self.layer_width + [self.out_features]))
class SNMLP(MLP):
    """MLP variant whose output Linear layer is wrapped in ``spectral_norm``.

    Hidden layers come from ``combo``, which defaults to ``SNLinearCombo``.
    """

    def __init__(
        self, in_features: int, out_features: int, layer_width: list,
        combo=SNLinearCombo):
        super().__init__(in_features, out_features, layer_width, combo)

    def _build_model(self, combo):
        sizes = self.layer_sizes
        hidden, (last_in, last_out) = sizes[:-1], sizes[-1]
        layers = nn.Sequential()
        position = -1
        for position, (fan_in, fan_out) in enumerate(hidden):
            layers.add_module(str(position), combo(fan_in, fan_out))
        # Built after the hidden layers so parameter initialisation happens
        # in the same order as before.
        head = spectral_norm(nn.Linear(last_in, last_out))
        layers.add_module(str(position + 1), head)
        return layers