rahul7star commited on
Commit
9c8674b
·
verified ·
1 Parent(s): 6cd5c6f

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +17 -21
app_quant_latent.py CHANGED
@@ -691,56 +691,52 @@ def generate_image_all_latents(prompt, height, width, steps, seed, guidance_scal
691
  @spaces.GPU
692
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
693
  LOGS = []
694
- device = "cuda" if torch.cuda.is_available() else "cpu"
695
  generator = torch.Generator(device).manual_seed(int(seed))
696
 
697
  placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
698
  latent_gallery = []
699
  final_gallery = []
700
 
701
- # --- Generate latent previews in a loop ---
702
  try:
703
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
704
-
705
- # always keep latents float32 until decode
706
- latents = latents.float()
707
 
708
  num_previews = min(10, steps)
709
  preview_steps = torch.linspace(0, 1, num_previews)
710
 
711
- for i, alpha in enumerate(preview_steps):
712
  try:
713
  with torch.no_grad():
714
- # --- simulate progression like Z-Image Turbo ---
715
- preview_latent = latents * alpha + torch.randn_like(latents) * (1 - alpha)
716
 
717
- # 🛠 FIX: move to same device as VAE and match dtype
718
  preview_latent = preview_latent.to(pipe.vae.device).to(pipe.vae.dtype)
719
 
720
- # Decode latent
721
- decoded = pipe.vae.decode(preview_latent).sample # [1,3,H,W]
722
- decoded = (decoded / 2 + 0.5).clamp(0, 1)
723
 
724
- # Convert to PIL
725
- decoded = decoded[0].permute(1, 2, 0).cpu().numpy() # HWC
726
- latent_img = Image.fromarray((decoded * 255).astype("uint8"))
 
 
727
 
728
  except Exception as e:
729
  LOGS.append(f"⚠️ Latent preview decode failed: {e}")
730
  latent_img = placeholder
731
 
732
  latent_gallery.append(latent_img)
733
- yield None, latent_gallery, LOGS # update Gradio with intermediate preview
734
-
735
- # Optionally, you can store/upload last few latents here for later
736
- # last_latents = latents[-4:].cpu()
737
 
738
  except Exception as e:
739
  LOGS.append(f"⚠️ Latent generation failed: {e}")
740
  latent_gallery.append(placeholder)
741
  yield None, latent_gallery, LOGS
742
 
743
- # --- Final image: standard pipeline ---
744
  try:
745
  output = pipe(
746
  prompt=prompt,
@@ -752,7 +748,7 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
752
  )
753
  final_img = output.images[0]
754
  final_gallery.append(final_img)
755
- latent_gallery.append(final_img) # fallback preview if needed
756
  LOGS.append("✅ Standard pipeline succeeded.")
757
  yield final_img, latent_gallery, LOGS
758
 
 
691
  @spaces.GPU
692
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
693
  LOGS = []
694
+ device = "cuda"
695
  generator = torch.Generator(device).manual_seed(int(seed))
696
 
697
  placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
698
  latent_gallery = []
699
  final_gallery = []
700
 
701
+ # --- Generate latent previews ---
702
  try:
703
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
704
+ latents = latents.float() # keep float32 until decode
 
 
705
 
706
  num_previews = min(10, steps)
707
  preview_steps = torch.linspace(0, 1, num_previews)
708
 
709
+ for alpha in preview_steps:
710
  try:
711
  with torch.no_grad():
712
+ # Simulate denoising progression like Z-Image Turbo
713
+ preview_latent = latents * alpha + latents * 0 # optional: simple progression
714
 
715
+ # Move to same device and dtype as VAE
716
  preview_latent = preview_latent.to(pipe.vae.device).to(pipe.vae.dtype)
717
 
718
+ # Decode
719
+ decoded = pipe.vae.decode(preview_latent, return_dict=False)[0]
 
720
 
721
+ # Convert to PIL following same logic as final image
722
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
723
+ decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
724
+ decoded = (decoded * 255).round().astype("uint8")
725
+ latent_img = Image.fromarray(decoded[0])
726
 
727
  except Exception as e:
728
  LOGS.append(f"⚠️ Latent preview decode failed: {e}")
729
  latent_img = placeholder
730
 
731
  latent_gallery.append(latent_img)
732
+ yield None, latent_gallery, LOGS
 
 
 
733
 
734
  except Exception as e:
735
  LOGS.append(f"⚠️ Latent generation failed: {e}")
736
  latent_gallery.append(placeholder)
737
  yield None, latent_gallery, LOGS
738
 
739
+ # --- Final image: untouched ---
740
  try:
741
  output = pipe(
742
  prompt=prompt,
 
748
  )
749
  final_img = output.images[0]
750
  final_gallery.append(final_img)
751
+ latent_gallery.append(final_img) # fallback preview
752
  LOGS.append("✅ Standard pipeline succeeded.")
753
  yield final_img, latent_gallery, LOGS
754