import base64
import io

import nltk
import torch
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
from torchvision import transforms
class PreTrainedPipeline:
    """Text-to-image pipeline: turns a class name into a BigGAN-generated image."""

    def __init__(self, path=""):
        """
        Initialize model

        Args:
            path (:obj:`str`):
                value forwarded to ``BigGAN.from_pretrained`` to locate the
                generator weights.
        """
        # Download the WordNet corpus up front; presumably this is what
        # `one_hot_from_names` uses to resolve class names -- confirm against
        # the pytorch_pretrained_biggan documentation.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # This pipeline only ever runs inference, so switch the module to
        # evaluation mode once here instead of relying on the loader's default.
        self.model.eval()
        # Truncation used both to sample the noise and as the third argument
        # of the generator call in `__call__`.
        self.truncation = 0.1

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        Raises:
            ValueError: if `one_hot_from_names` returns ``None`` for `inputs`.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        # Both helpers hand back NumPy arrays; convert them to tensors for the model.
        noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        # No gradients are needed for generation.
        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # Batch size is 1, so keep the only sample.
        img = output[0]
        # Shift and scale before the PIL conversion; assumes the generator
        # emits values in [-1, 1] so the result lands in [0, 1] -- TODO confirm.
        img = (img + 1) / 2.0
        return transforms.ToPILImage()(img)