|
|
import os |
|
|
import tempfile |
|
|
import torch |
|
|
import time |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from AdaIN import AdaINNet |
|
|
from PIL import Image |
|
|
from torchvision.utils import save_image |
|
|
from torchvision.transforms import ToPILImage |
|
|
from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range |
|
|
from glob import glob |
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
|
|
|
def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Stylize a content image with a style image via adaptive instance
    normalization (AdaIN).

    Both images are passed through ``encoder``; the content features are
    re-normalized against the style features with
    ``adaptive_instance_normalization``; the result is blended with the
    untouched content features according to ``alpha`` and handed to
    ``decoder``.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder network (callable on an image tensor)
        decoder: Decoder network (callable on a feature tensor)
        alpha (float, default=1.0): Weight of the stylized features; the
            plain content features receive weight ``1 - alpha``

    Return:
        output_tensor (torch.FloatTensor): Whatever ``decoder`` produces for
            the blended features, i.e. the style transfer output image
    """
    # Encode content first, then style (same order as the calls are written).
    content_features, style_features = encoder(content_tensor), encoder(style_tensor)

    stylized_features = adaptive_instance_normalization(content_features, style_features)

    # Linear blend: alpha=1.0 keeps only the stylized features,
    # alpha=0.0 keeps only the original content features.
    stylized_part = alpha * stylized_features
    content_part = (1 - alpha) * content_features
    return decoder(stylized_part + content_part)
|
|
|
|
|
def _load_image_tensor(image_pth, to_tensor):
    """Open ``image_pth``, apply ``to_tensor`` and return a batched tensor on ``device``."""
    # The context manager closes the file handle once the pixels are read.
    # Converting to RGB lets grayscale / palette / RGBA files through the
    # encoder, which presumably expects 3-channel input
    # (NOTE(review): confirm against AdaINNet and utils.transform).
    with Image.open(image_pth) as img:
        return to_tensor(img.convert('RGB')).unsqueeze(0).to(device)


def run_adain(content_dir, style_dir, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth', alpha=1.0):
    """
    Stylize every file in ``content_dir`` with every file in ``style_dir``
    and save the results into a freshly created temporary directory.

    For each (content, style) pair the time spent in ``style_transfer`` is
    printed, and an average over all pairs is printed at the end.

    Args:
        content_dir (str or os.PathLike): Directory whose entries are used
            as content images
        style_dir (str or os.PathLike): Directory whose entries are used
            as style images
        vgg_pth (str, default='vgg_normalized.pth'): Checkpoint passed to
            ``AdaINNet`` for the encoder
        decoder_pth (str, default='decoder.pth'): State dict loaded into
            ``model.decoder``
        alpha (float, default=1.0): Weight of style image feature, forwarded
            to ``style_transfer``

    Return:
        out_dir (str): Path of the temporary directory containing the
            output images, named
            ``<content>_style_<style>_alpha<alpha><content suffix>``

    Raises:
        AssertionError: If ``content_dir`` or ``style_dir`` matches no files
    """
    content_pths = [Path(f) for f in glob(os.path.join(content_dir, '*'))]
    style_pths = [Path(f) for f in glob(os.path.join(style_dir, '*'))]

    # Raised explicitly (rather than via ``assert``) so the checks survive
    # ``python -O``; the exception type callers see is unchanged.
    if not content_pths:
        raise AssertionError('Failed to load content image')
    if not style_pths:
        raise AssertionError('Failed to load style image')

    # mkdtemp() already creates the directory and returns its path
    # without a trailing separator.
    out_dir = tempfile.mkdtemp()

    # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
    vgg = torch.load(vgg_pth, map_location=device)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
    model.eval()

    t = transform(512)

    times = []

    for content_pth in content_pths:
        content_tensor = _load_image_tensor(content_pth, t)

        for style_pth in style_pths:
            style_tensor = _load_image_tensor(style_pth, t)

            tic = time.perf_counter()
            with torch.no_grad():
                # .cpu() is inside the timed region, so the measurement
                # includes moving the result back to host memory.
                out_tensor = style_transfer(
                    content_tensor, style_tensor, model.encoder, model.decoder, alpha
                ).cpu()
            toc = time.perf_counter()
            elapsed = toc - tic

            print("Content: %s. Style: %s. Alpha: %s. Style Transfer time: %.4f seconds"
                  % (content_pth.stem, style_pth.stem, str(alpha), elapsed))
            times.append(elapsed)

            out_name = (content_pth.stem + '_style_' + style_pth.stem
                        + '_alpha' + str(alpha) + content_pth.suffix)
            # Join with a path separator so the file lands inside out_dir
            # instead of next to it with the directory name as a prefix.
            save_image(out_tensor, os.path.join(out_dir, out_name))

    # The first measurement is left out whenever more than one exists
    # (presumably because it carries one-off warm-up cost).
    measured = times[1:] or times
    avg = sum(measured) / len(measured)
    print("Average style transfer time: %.4f seconds" % (avg))

    return out_dir