Add adjustable dataset size and random sampling from the style dataset
Browse files — test_api.py (+18, −12)
test_api.py
CHANGED
|
@@ -39,10 +39,11 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
|
|
| 39 |
mix_enc = alpha * transfer_enc + (1-alpha) * content_enc
|
| 40 |
return decoder(mix_enc)
|
| 41 |
|
| 42 |
-
def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
|
| 43 |
content_pths = [Path(f) for f in glob(content_dir+'/*')]
|
|
|
|
| 44 |
|
| 45 |
-
assert
|
| 46 |
|
| 47 |
# Load AdaIN model
|
| 48 |
vgg = torch.load(vgg_pth)
|
|
@@ -57,17 +58,22 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_n
|
|
| 57 |
times = []
|
| 58 |
|
| 59 |
style_ds = load_dataset(style_dataset_pth, split="train")
|
| 60 |
-
# do i need to stick a dataloader around this? idk
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
style_tensor = t(style_img).unsqueeze(0).to(device)
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Start time
|
| 73 |
tic = time.perf_counter()
|
|
@@ -83,7 +89,7 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_n
|
|
| 83 |
times.append(toc-tic)
|
| 84 |
|
| 85 |
# Save image
|
| 86 |
-
out_pth = out_dir + content_pth.stem + '_style_' + str(
|
| 87 |
out_pth += content_pth.suffix
|
| 88 |
save_image(out_tensor, out_pth)
|
| 89 |
|
|
|
|
| 39 |
mix_enc = alpha * transfer_enc + (1-alpha) * content_enc
|
| 40 |
return decoder(mix_enc)
|
| 41 |
|
| 42 |
+
def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, dataset_size=100, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
|
| 43 |
content_pths = [Path(f) for f in glob(content_dir+'/*')]
|
| 44 |
+
num_content_imgs = len(content_pths)
|
| 45 |
|
| 46 |
+
assert num_content_imgs > 0, 'Failed to load content image'
|
| 47 |
|
| 48 |
# Load AdaIN model
|
| 49 |
vgg = torch.load(vgg_pth)
|
|
|
|
| 58 |
times = []
|
| 59 |
|
| 60 |
style_ds = load_dataset(style_dataset_pth, split="train")
|
|
|
|
| 61 |
|
| 62 |
+
if num_content_imgs * len(style_ds) > dataset_size:
|
| 63 |
+
num_style_per_content = int(np.ceil(dataset_size / num_content_imgs))
|
| 64 |
+
else:
|
| 65 |
+
num_style_per_content = len(style_ds)
|
|
|
|
| 66 |
|
| 67 |
+
for content_pth in content_pths:
|
| 68 |
+
content_img = Image.open(content_pth)
|
| 69 |
+
content_tensor = t(content_img).unsqueeze(0).to(device)
|
| 70 |
+
indices = random.sample(range(len(style_ds)), num_style_per_content)
|
| 71 |
+
|
| 72 |
+
for idx in indices:
|
| 73 |
+
style_img = style_ds[idx]['image']
|
| 74 |
+
if style_img.mode not in ("RGB", "L"):
|
| 75 |
+
style_img = style_img.convert("RGB")
|
| 76 |
+
style_tensor = t(style_img).unsqueeze(0).to(device)
|
| 77 |
|
| 78 |
# Start time
|
| 79 |
tic = time.perf_counter()
|
|
|
|
| 89 |
times.append(toc-tic)
|
| 90 |
|
| 91 |
# Save image
|
| 92 |
+
out_pth = out_dir + content_pth.stem + '_style_' + str(idx) + '_alpha' + str(alpha)
|
| 93 |
out_pth += content_pth.suffix
|
| 94 |
save_image(out_tensor, out_pth)
|
| 95 |
|