rahul7star committed on
Commit 6cd5c6f · verified · 1 Parent(s): fd1b8e0

Update app_quant_latent.py

Files changed (1)
  1. app_quant_latent.py +11 -15
app_quant_latent.py CHANGED
@@ -691,7 +691,7 @@ def generate_image_all_latents(prompt, height, width, steps, seed, guidance_scal
 @spaces.GPU
 def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
     LOGS = []
-    device = "cuda"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     generator = torch.Generator(device).manual_seed(int(seed))
 
     placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
@@ -711,33 +711,29 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
         for i, alpha in enumerate(preview_steps):
             try:
                 with torch.no_grad():
-                    # Simulate progression
+                    # --- simulate progression like Z-Image Turbo ---
                     preview_latent = latents * alpha + torch.randn_like(latents) * (1 - alpha)
+
                     # 🛠 FIX: move to same device as VAE and match dtype
                     preview_latent = preview_latent.to(pipe.vae.device).to(pipe.vae.dtype)
 
                     # Decode latent
-                    latent_img_tensor = pipe.vae.decode(preview_latent).sample # [1,3,H,W]
-                    latent_img_tensor = (latent_img_tensor / 2 + 0.5).clamp(0, 1)
+                    decoded = pipe.vae.decode(preview_latent).sample # [1,3,H,W]
+                    decoded = (decoded / 2 + 0.5).clamp(0, 1)
 
                     # Convert to PIL
-                    latent_img_tensor = latent_img_tensor[0].permute(1, 2, 0).cpu().numpy() # HWC
-                    latent_img = Image.fromarray((latent_img_tensor * 255).astype("uint8"))
+                    decoded = decoded[0].permute(1, 2, 0).cpu().numpy() # HWC
+                    latent_img = Image.fromarray((decoded * 255).astype("uint8"))
 
             except Exception as e:
                 LOGS.append(f"⚠️ Latent preview decode failed: {e}")
                 latent_img = placeholder
 
             latent_gallery.append(latent_img)
-            yield None, latent_gallery, LOGS
-
-            # Upload latents (optional)
-            # latent_dict = {"latents": latents.cpu(), "prompt": prompt, "seed": seed}
-            # try:
-            #     hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_{seed}.pt")
-            #     LOGS.append(f"🔹 Latents uploaded: {hf_url}")
-            # except Exception as e:
-            #     LOGS.append(f"⚠️ Failed to upload latents: {e}")
+            yield None, latent_gallery, LOGS  # update Gradio with intermediate preview
+
+            # Optionally, you can store/upload last few latents here for later
+            # last_latents = latents[-4:].cpu()
 
     except Exception as e:
         LOGS.append(f"⚠️ Latent generation failed: {e}")
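For readers skimming the diff: below is a minimal, standalone sketch of the preview-decode path this commit refactors, assuming a diffusers-style `pipe` whose `vae.decode(...)` returns an object with a `.sample` image tensor roughly in [-1, 1]. The helper name `decode_preview` is illustrative and not part of the app.

import torch
from PIL import Image

def decode_preview(pipe, latents: torch.Tensor, alpha: float) -> Image.Image:
    """Blend latents with noise by `alpha`, decode with the VAE, return a PIL image (sketch only)."""
    with torch.no_grad():
        # Simulate denoising progress: pure noise at alpha=0, clean latents at alpha=1.
        preview_latent = latents * alpha + torch.randn_like(latents) * (1 - alpha)
        # Match the VAE's device and dtype before decoding (the "🛠 FIX" in the commit).
        preview_latent = preview_latent.to(pipe.vae.device, dtype=pipe.vae.dtype)
        # NOTE: a real diffusers VAE may also expect latents scaled by 1 / vae.config.scaling_factor.
        decoded = pipe.vae.decode(preview_latent).sample       # [1, 3, H, W], roughly in [-1, 1]
        decoded = (decoded / 2 + 0.5).clamp(0, 1)              # rescale to [0, 1]
        array = decoded[0].permute(1, 2, 0).float().cpu().numpy()  # HWC; .float() handles fp16/bf16
        return Image.fromarray((array * 255).astype("uint8"))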
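The new `yield None, latent_gallery, LOGS` line leans on Gradio's generator support: each yield from the event handler pushes the current gallery and log list to the UI before the final image exists. A hedged sketch of that wiring follows; the component names and placeholder frames are illustrative, not the app's actual layout.

import gradio as gr
from PIL import Image

def preview_stream(prompt):
    gallery, logs = [], []
    for step in range(4):
        logs.append(f"step {step}: decoding preview")
        # Stand-in for the decoded latent preview; the app appends real PIL frames here.
        gallery.append(Image.new("RGB", (64, 64), color=(255, 255, 255)))
        yield None, gallery, logs  # final image slot stays empty until the run finishes

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    final_image = gr.Image(label="Final image")
    previews = gr.Gallery(label="Latent previews")
    log_view = gr.JSON(label="Logs")
    prompt_box.submit(preview_stream, inputs=prompt_box, outputs=[final_image, previews, log_view])

# demo.launch()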