rizavelioglu committed
Commit · 05d50b7 · 1 Parent(s): 5c1a861
enable TinyAE
app.py
CHANGED
@@ -56,7 +56,7 @@ class VAETester:
             # "dc-ae-f32c32-sana-1.0": AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers").to(self.device),
             "FLUX.1-Kontext": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae").to(self.device),
             "FLUX.2": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae").to(self.device),
-
+            "FLUX.2-TinyAutoEncoder": AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device),
         }
         # Define the desired order of models
         order = [
@@ -72,31 +72,28 @@ class VAETester:
             # "dc-ae-f32c32-sana-1.0",
             "FLUX.1-Kontext",
             "FLUX.2",
-
+            "FLUX.2-TinyAutoEncoder",
         ]

         # Construct the vae_models dictionary in the specified order
         return {name: {"vae": vaes[name], "dtype": torch.bfloat16 if name == "FLUX.2-TinyAutoEncoder" else torch.float32} for name in order}

-    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float):
+    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float, vae_name: str):
         """Process image through a single VAE model"""
         dtype = model_config["dtype"]
+        vae = model_config["vae"]
         img_transformed = self.input_transform(img).to(dtype).to(self.device).unsqueeze(0)
         original_base = self.base_transform(img).cpu()

-        #
+        # Time the encoding-decoding process
         start_time = time.time()
-
-        vae = model_config["vae"]
         with torch.no_grad():
-            if
+            if vae_name == "FLUX.2-TinyAutoEncoder":
                 encoded = vae.encode(img_transformed, return_dict=False)
                 decoded = vae.decode(encoded, return_dict=False)
             else:
                 encoded = vae.encode(img_transformed).latent_dist.sample()
                 decoded = vae.decode(encoded).sample
-
-        # End timer
         processing_time = time.time() - start_time

         decoded_transformed = self.output_transform(decoded.squeeze(0).to(torch.float32)).cpu()
@@ -111,8 +108,8 @@ class VAETester:
     def process_all_models(self, img: torch.Tensor, tolerance: float):
         """Process image through all configured VAEs"""
         results = {}
-        for
-        results[
+        for vae_name, model_config in self.vae_models.items():
+            results[vae_name] = self.process_image(img, model_config, tolerance, vae_name)
         return results

     @spaces.GPU(duration=20)
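
Note on the two code paths: the standard diffusers AutoencoderKL returns a latent distribution from encode() and a model output from decode(), while the fal Tiny AutoEncoder (loaded via remote code) is called with return_dict=False and, as used in this diff, hands back plain tensors. Below is a minimal round-trip sketch of both paths, assuming AutoModel is the diffusers auto class (its import sits outside this diff's context) and a dummy input scaled to [-1, 1]:

import torch
from diffusers import AutoencoderKL, AutoModel  # assumption: the AutoModel import is not shown in the diff

# Dummy batch: one RGB image, values in [-1, 1] (assumed input range)
img = torch.rand(1, 3, 256, 256) * 2 - 1

# Standard VAE path: encode() yields a latent distribution, decode() a model output
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae")
with torch.no_grad():
    latents = vae.encode(img).latent_dist.sample()
    recon = vae.decode(latents).sample

# Tiny AutoEncoder path (as in the diff): return_dict=False returns tensors directly
tiny = AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16)
with torch.no_grad():
    latents = tiny.encode(img.to(torch.bfloat16), return_dict=False)
    recon = tiny.decode(latents, return_dict=False)

print(recon.to(torch.float32).shape)  # reconstruction should match the input resolution

Both branches are timed identically inside process_image, so the tiny model's speed advantage shows up directly in processing_time.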