tidalove commited on
Commit
68ca7bf
·
verified ·
1 Parent(s): 2713d3b

add adjustable size, random sampling from style ds

Browse files
Files changed (1) hide show
  1. test_api.py +18 -12
test_api.py CHANGED
@@ -39,10 +39,11 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
39
  mix_enc = alpha * transfer_enc + (1-alpha) * content_enc
40
  return decoder(mix_enc)
41
 
42
- def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
43
  content_pths = [Path(f) for f in glob(content_dir+'/*')]
 
44
 
45
- assert len(content_pths) > 0, 'Failed to load content image'
46
 
47
  # Load AdaIN model
48
  vgg = torch.load(vgg_pth)
@@ -57,17 +58,22 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_n
57
  times = []
58
 
59
  style_ds = load_dataset(style_dataset_pth, split="train")
60
- # do i need to stick a dataloader around this? idk
61
 
62
- for style_idx, style_item in enumerate(style_ds):
63
- style_img = style_item['image']
64
- if style_img.mode not in ("RGB", "L"):
65
- style_img = style_img.convert("RGB")
66
- style_tensor = t(style_img).unsqueeze(0).to(device)
67
 
68
- for content_pth in content_pths:
69
- content_img = Image.open(content_pth)
70
- content_tensor = t(content_img).unsqueeze(0).to(device)
 
 
 
 
 
 
 
71
 
72
  # Start time
73
  tic = time.perf_counter()
@@ -83,7 +89,7 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_n
83
  times.append(toc-tic)
84
 
85
  # Save image
86
- out_pth = out_dir + content_pth.stem + '_style_' + str(style_idx) + '_alpha' + str(alpha)
87
  out_pth += content_pth.suffix
88
  save_image(out_tensor, out_pth)
89
 
 
39
  mix_enc = alpha * transfer_enc + (1-alpha) * content_enc
40
  return decoder(mix_enc)
41
 
42
+ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, dataset_size=100, vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
43
  content_pths = [Path(f) for f in glob(content_dir+'/*')]
44
+ num_content_imgs = len(content_pths)
45
 
46
+ assert num_content_imgs > 0, 'Failed to load content image'
47
 
48
  # Load AdaIN model
49
  vgg = torch.load(vgg_pth)
 
58
  times = []
59
 
60
  style_ds = load_dataset(style_dataset_pth, split="train")
 
61
 
62
+ if num_content_imgs * len(style_ds) > dataset_size:
63
+ num_style_per_content = int(np.ceil(dataset_size / num_content_imgs))
64
+ else:
65
+ num_style_per_content = len(style_ds)
 
66
 
67
+ for content_pth in content_pths:
68
+ content_img = Image.open(content_pth)
69
+ content_tensor = t(content_img).unsqueeze(0).to(device)
70
+ indices = random.sample(range(len(style_ds)), num_style_per_content)
71
+
72
+ for idx in indices:
73
+ style_img = style_ds[idx]['image']
74
+ if style_img.mode not in ("RGB", "L"):
75
+ style_img = style_img.convert("RGB")
76
+ style_tensor = t(style_img).unsqueeze(0).to(device)
77
 
78
  # Start time
79
  tic = time.perf_counter()
 
89
  times.append(toc-tic)
90
 
91
  # Save image
92
+ out_pth = out_dir + content_pth.stem + '_style_' + str(idx) + '_alpha' + str(alpha)
93
  out_pth += content_pth.suffix
94
  save_image(out_tensor, out_pth)
95