add color control to test_interpolate.py
- .gitignore +1 -1
- README.md +1 -1
- test_interpolate.py +12 -3
.gitignore CHANGED
@@ -2,4 +2,4 @@
 /__pycache__/
 
 #Ignore results
-/results
+/results*/
README.md CHANGED
@@ -73,7 +73,7 @@ optional arguments:
 To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by comma. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. Specify `--grid_pth` to interpolate with different built-in weights and provide 4 style images.
 
 ```
-$ python
+$ python test_interpolate.py --content_image $IMG --style_image $STYLE $WEIGHT --cuda
 
 optional arguments:
 -h, --help show this help message and exit
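The README change above only adds the basic interpolation command. This commit also introduces a `--color_control` flag in test_interpolate.py (see the diff below); the README is not updated to mention it, but an illustrative invocation using the same placeholder variables would be:

```
$ python test_interpolate.py --content_image $IMG --style_image $STYLE $WEIGHT --color_control --cuda
```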
test_interpolate.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
 from AdaIN import AdaINNet
 from PIL import Image
 from torchvision.utils import save_image
-from utils import adaptive_instance_normalization, transform, Range, grid_image
+from utils import adaptive_instance_normalization, transform,linear_histogram_matching, Range, grid_image
 from glob import glob
 
 parser = argparse.ArgumentParser()
@@ -19,6 +19,7 @@ parser.add_argument('--interpolation_weights', type=str, help='Weights of interp
 parser.add_argument('--cuda', action='store_true', help='Use CUDA')
 parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
 					if use grid mode, provide 4 style images')
+parser.add_argument('--color_control', action='store_true', help='Preserve content color')
 args = parser.parse_args()
 assert args.content_image
 assert args.style_image
@@ -106,7 +107,13 @@ def main():
 	style_tensor = []
 	for style_pth in style_pths:
 		img = Image.open(style_pth)
-		style_tensor.append(transform([512, 512])(img))
+		if args.color_control:
+			img = transform([512,512])(img).unsqueeze(0)
+			img = linear_histogram_matching(content_tensor,img)
+			img = img.squeeze(0)
+			style_tensor.append(img)
+		else:
+			style_tensor.append(transform([512, 512])(img))
 	style_tensor = torch.stack(style_tensor, dim=0).to(device)
 
 	for inter_weight in inter_weights:
@@ -117,7 +124,9 @@ def main():
 		print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
 
 		# Save results
-		out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
+		out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
+		if args.color_control: out_pth += '_colorcontrol'
+		out_pth += content_pth.suffix
 		save_image(out_tensor, out_pth)
 
 		if args.grid_pth:
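The `--color_control` branch above matches each style image to the content image's color statistics (via `linear_histogram_matching` from utils) before AdaIN stylization, so the interpolated output keeps the content's palette. The implementation of `linear_histogram_matching` is not part of this diff; as a rough sketch, assuming it performs a per-channel linear (mean/std) match of the style tensor to the content tensor with the call order used above (content first, style second), it might look like:

```python
import torch

def linear_histogram_matching(target: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
	"""Linearly shift source's per-channel statistics to match target's.

	Both tensors are assumed to be (N, C, H, W). This matches only the
	first two moments (mean and std) per channel, not full histograms;
	the repository's actual utils implementation may differ (e.g. it
	could match the full 3x3 channel covariance instead).
	"""
	# Per-channel statistics over the spatial dimensions
	t_mean = target.mean(dim=(2, 3), keepdim=True)
	t_std = target.std(dim=(2, 3), keepdim=True)
	s_mean = source.mean(dim=(2, 3), keepdim=True)
	s_std = source.std(dim=(2, 3), keepdim=True)

	# Whiten the source channels, then re-color with the target statistics
	matched = (source - s_mean) / (s_std + 1e-5) * t_std + t_mean
	return matched.clamp(0.0, 1.0)
```

Under this assumption the call leaves the style image's tensor shape unchanged, which is why the surrounding `unsqueeze(0)`/`squeeze(0)` pair in the loop is all that is needed before appending to `style_tensor`.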