File size: 3,295 Bytes
a62af7c
 
 
 
 
a6e34e8
a62af7c
 
 
 
 
 
 
02582df
a62af7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68ca7bf
decbd98
68ca7bf
a62af7c
68ca7bf
a62af7c
 
 
 
 
 
 
 
 
 
 
 
 
667d357
a62af7c
68ca7bf
 
 
 
a62af7c
68ca7bf
 
 
 
 
 
 
 
 
 
02582df
decbd98
 
02582df
decbd98
 
 
02582df
decbd98
 
 
 
 
a62af7c
decbd98
68ca7bf
decbd98
 
a62af7c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import tempfile
import torch
import time
import numpy as np
import random
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
from glob import glob
from datasets import load_dataset

# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
	device = torch.device('cuda')
else:
	device = torch.device('cpu')

def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
	"""
	Run a single AdaIN style-transfer pass and return the decoded image.

	The content image and then the style image are passed through the encoder.
	The two feature maps are combined with adaptive_instance_normalization
	(from utils), and the result is blended with the untouched content
	features using alpha. The decoder turns the blended features back into
	an image.

	Args:
		content_tensor (torch.FloatTensor): Content image
		style_tensor (torch.FloatTensor): Style image
		encoder: Encoder (vgg19) network
		decoder: Decoder network
		alpha (float, default=1.0): Weight of the stylized features. With 1.0
			only the AdaIN output is decoded; with 0.0 the decoder receives
			the plain content features.

	Return:
		torch.FloatTensor: Style transfer output image
	"""
	feats_content = encoder(content_tensor)
	feats_style = encoder(style_tensor)

	stylized = adaptive_instance_normalization(feats_content, feats_style)

	# Linear interpolation between content features and stylized features
	blended = (1 - alpha) * feats_content + alpha * stylized
	return decoder(blended)

def _as_rgb_or_l(img):
	"""Return img unchanged if its mode is "RGB" or "L", otherwise a copy converted to "RGB"."""
	if img.mode not in ("RGB", "L"):
		img = img.convert("RGB")
	return img

def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, dataset_size=100, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
	"""
	Stylize every file in content_dir with style images sampled at random from
	a dataset, and save each result into out_dir.

	Args:
		content_dir (str): Directory holding the content images. Only regular
			files directly inside it are used.
		style_dataset_pth (str): Name/path passed to datasets.load_dataset. Its
			"train" split is used, and each row must provide an 'image' field.
		out_dir (str): Output directory. It is created if it does not exist.
		alpha (float, default=1.0): Weight of style image feature
		dataset_size (int, default=100): Target total number of outputs. Each
			content image gets ceil(dataset_size / num_content_imgs) distinct
			random styles, capped at the number of style images available.
		vgg_pth (str): Path of the saved encoder object handed to AdaINNet
		decoder_pth (str): Path of the decoder state dict

	Output files are named
	<content stem>_style_<style index>_alpha<alpha><content suffix>.
	Per-image timing is printed, followed by the average time when more than
	one transfer ran.
	"""
	# Keep regular files only (a subdirectory would make Image.open fail).
	# glob() also skips dotfiles; sorting makes the processing order deterministic.
	content_pths = [Path(f) for f in sorted(glob(os.path.join(content_dir, '*'))) if os.path.isfile(f)]
	num_content_imgs = len(content_pths)

	assert num_content_imgs > 0, 'Failed to load content image'

	os.makedirs(out_dir, exist_ok=True)

	# Load AdaIN model; map_location lets checkpoints saved on a GPU load on a CPU-only host
	vgg = torch.load(vgg_pth, map_location=device)
	model = AdaINNet(vgg).to(device)
	model.decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
	model.eval()

	# Prepare image transform
	t = transform(512)

	# Per-transfer wall-clock times in seconds
	times = []

	style_ds = load_dataset(style_dataset_pth, split="train")
	num_style_imgs = len(style_ds)

	# Sample enough styles per content image to reach about dataset_size outputs,
	# but never more than the number of style images that exist.
	if num_content_imgs * num_style_imgs > dataset_size:
		num_style_per_content = int(np.ceil(dataset_size / num_content_imgs))
	else:
		num_style_per_content = num_style_imgs

	for content_pth in content_pths:
		# Close the file handle as soon as the tensor is built; normalize the
		# mode the same way as the style images.
		with Image.open(content_pth) as content_img:
			content_tensor = t(_as_rgb_or_l(content_img)).unsqueeze(0).to(device)
		indices = random.sample(range(num_style_imgs), num_style_per_content)

		for idx in indices:
			style_img = _as_rgb_or_l(style_ds[idx]['image'])
			style_tensor = t(style_img).unsqueeze(0).to(device)

			# Timed span covers the forward pass plus the .cpu() copy of the result
			tic = time.perf_counter()
			with torch.no_grad():
				out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
			toc = time.perf_counter()

			elapsed = toc - tic
			print(f"Content: {content_pth.stem}. Style: {idx}. Alpha: {alpha}. Style Transfer time: {elapsed:.4f} seconds")
			times.append(elapsed)

			# Save image inside out_dir (a proper path join, not string concatenation)
			out_name = f"{content_pth.stem}_style_{idx}_alpha{alpha}{content_pth.suffix}"
			save_image(out_tensor, os.path.join(out_dir, out_name))

	# Drop the first measurement. It presumably includes one-time warm-up cost
	# (e.g. CUDA/cuDNN initialization) and would skew the average.
	if len(times) > 1:
		times.pop(0)
		avg = sum(times) / len(times)
		print("Average style transfer time: %.4f seconds" % (avg))