tidalove commited on
Commit
4a3a637
·
verified ·
1 Parent(s): 02582df

Update test_api.py

Browse files
Files changed (1) hide show
  1. test_api.py +3 -3
test_api.py CHANGED
@@ -61,7 +61,7 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_n
61
  style_ds = load_dataset(style_dataset_pth, split="train")
62
  # do i need to stick a dataloader around this? idk
63
 
64
- for style_item in style_ds:
65
  style_img = style_item['image']
66
  print(style_img)
67
  style_tensor = t(style_img).unsqueeze(0).to(device)
@@ -80,11 +80,11 @@ def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, vgg_pth='vgg_n
80
  # End time
81
  toc = time.perf_counter()
82
  print("Content: " + content_pth.stem + ". Style: " \
83
- + style_pth.stem + '. Alpha: ' + str(alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
84
  times.append(toc-tic)
85
 
86
  # Save image
87
- out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(alpha)
88
  out_pth += content_pth.suffix
89
  save_image(out_tensor, out_pth)
90
 
 
61
  style_ds = load_dataset(style_dataset_pth, split="train")
62
  # do i need to stick a dataloader around this? idk
63
 
64
+ for style_idx, style_item in enumerate(style_ds):
65
  style_img = style_item['image']
66
  print(style_img)
67
  style_tensor = t(style_img).unsqueeze(0).to(device)
 
80
  # End time
81
  toc = time.perf_counter()
82
  print("Content: " + content_pth.stem + ". Style: " \
83
+ + str(style_idx) + '. Alpha: ' + str(alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
84
  times.append(toc-tic)
85
 
86
  # Save image
87
+ out_pth = out_dir + content_pth.stem + '_style_' + str(style_idx) + '_alpha' + str(alpha)
88
  out_pth += content_pth.suffix
89
  save_image(out_tensor, out_pth)
90