rizavelioglu committed on
Commit 05d50b7 · 1 Parent(s): 5c1a861

enable TinyAE

Files changed (1):
  1. app.py (+8, -11)
app.py CHANGED
@@ -56,7 +56,7 @@ class VAETester:
             # "dc-ae-f32c32-sana-1.0": AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers").to(self.device),
             "FLUX.1-Kontext": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae").to(self.device),
             "FLUX.2": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae").to(self.device),
-            # "FLUX.2-TinyAutoEncoder": AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device),
+            "FLUX.2-TinyAutoEncoder": AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device),
         }
         # Define the desired order of models
         order = [
@@ -72,31 +72,28 @@ class VAETester:
             # "dc-ae-f32c32-sana-1.0",
             "FLUX.1-Kontext",
             "FLUX.2",
-            # "FLUX.2-TinyAutoEncoder",
+            "FLUX.2-TinyAutoEncoder",
         ]
 
         # Construct the vae_models dictionary in the specified order
         return {name: {"vae": vaes[name], "dtype": torch.bfloat16 if name == "FLUX.2-TinyAutoEncoder" else torch.float32} for name in order}
 
-    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float):
+    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float, vae_name: str):
         """Process image through a single VAE model"""
         dtype = model_config["dtype"]
+        vae = model_config["vae"]
         img_transformed = self.input_transform(img).to(dtype).to(self.device).unsqueeze(0)
         original_base = self.base_transform(img).cpu()
 
-        # Start timer
+        # Time the encoding-decoding process
         start_time = time.time()
-
-        vae = model_config["vae"]
         with torch.no_grad():
-            if isinstance(vae, AutoModel):
+            if vae_name == "FLUX.2-TinyAutoEncoder":
                 encoded = vae.encode(img_transformed, return_dict=False)
                 decoded = vae.decode(encoded, return_dict=False)
             else:
                 encoded = vae.encode(img_transformed).latent_dist.sample()
                 decoded = vae.decode(encoded).sample
-
-        # End timer
         processing_time = time.time() - start_time
 
         decoded_transformed = self.output_transform(decoded.squeeze(0).to(torch.float32)).cpu()
@@ -111,8 +108,8 @@ class VAETester:
     def process_all_models(self, img: torch.Tensor, tolerance: float):
         """Process image through all configured VAEs"""
         results = {}
-        for name, model_config in self.vae_models.items():
-            results[name] = self.process_image(img, model_config, tolerance)
+        for vae_name, model_config in self.vae_models.items():
+            results[vae_name] = self.process_image(img, model_config, tolerance, vae_name)
         return results
 
     @spaces.GPU(duration=20)
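
The two branches in process_image exist because the autoencoders have different return conventions: the Tiny AutoEncoder's encode/decode are called with return_dict=False and their outputs are used directly, while AutoencoderKL returns a posterior (latent_dist) that is sampled before decoding and a DecoderOutput whose .sample holds the image. Below is a minimal standalone sketch of the same round trip; the model ids and call patterns are taken from the diff, but AutoModel coming from diffusers is an assumption (the app's imports are outside these hunks), and loading the FLUX.2 weights may require access to the gated repository.

```python
import torch
from diffusers import AutoencoderKL, AutoModel  # assumption: AutoModel is diffusers' Auto class

device = "cuda" if torch.cuda.is_available() else "cpu"

# Full FLUX.2 VAE (float32) and the Tiny AutoEncoder (bfloat16), as configured in the diff.
kl_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.2-dev", subfolder="vae"
).to(device)
tiny_vae = AutoModel.from_pretrained(
    "fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)

# Dummy input; the app builds this tensor from a real image via its own transforms.
img = torch.rand(1, 3, 512, 512, device=device)

with torch.no_grad():
    # AutoencoderKL path: sample the latent posterior, then decode; the
    # reconstructed image lives in the DecoderOutput's .sample attribute.
    latents = kl_vae.encode(img).latent_dist.sample()
    recon_kl = kl_vae.decode(latents).sample

    # Tiny AutoEncoder path, mirroring the calls in app.py: encode/decode are
    # invoked with return_dict=False and their outputs are used directly.
    latents_tiny = tiny_vae.encode(img.to(torch.bfloat16), return_dict=False)
    recon_tiny = tiny_vae.decode(latents_tiny, return_dict=False)

print(recon_kl.shape, recon_tiny.shape)
```

One likely motivation for dispatching on vae_name instead of the old isinstance(vae, AutoModel) check is that from_pretrained on an Auto class returns an instance of the concrete model class rather than the Auto class itself, so the isinstance check would never match the Tiny AutoEncoder.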