rahul7star commited on
Commit
d419dc7
·
verified ·
1 Parent(s): 6a9ccac

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +21 -4
app_quant_latent.py CHANGED
@@ -513,22 +513,39 @@ import torch
513
  # Helper: Safe latent extractor
514
  # --------------------------
515
  def safe_get_latents(pipe, height, width, generator, device, LOGS):
 
 
 
 
516
  try:
 
 
 
 
 
 
 
 
517
  latents = pipe.prepare_latents(
518
  batch_size=1,
519
- num_channels=getattr(pipe.unet, "in_channels", 4),
520
  height=height,
521
  width=width,
522
  dtype=torch.float32,
523
  device=device,
524
- generator=generator
525
  )
 
526
  LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
527
  return latents
528
  except Exception as e:
529
  LOGS.append(f"⚠️ Latent extraction failed: {e}")
530
- return torch.randn((1, 4, height // 8, width // 8), generator=generator, device=device)
531
-
 
 
 
 
532
 
533
  # --------------------------
534
  # Main generation function (kept exactly as your logic)
 
513
  # Helper: Safe latent extractor
514
  # --------------------------
515
  def safe_get_latents(pipe, height, width, generator, device, LOGS):
516
+ """
517
+ Safely prepare latents for any ZImagePipeline variant.
518
+ Returns latents tensor, logs issues instead of failing.
519
+ """
520
  try:
521
+ # Determine number of channels
522
+ num_channels = 4 # default fallback
523
+ if hasattr(pipe, "unet") and hasattr(pipe.unet, "in_channels"):
524
+ num_channels = pipe.unet.in_channels
525
+ elif hasattr(pipe, "vae") and hasattr(pipe.vae, "latent_channels"):
526
+ num_channels = pipe.vae.latent_channels # some pipelines define this
527
+ LOGS.append(f"🔹 Using num_channels={num_channels} for latents")
528
+
529
  latents = pipe.prepare_latents(
530
  batch_size=1,
531
+ num_channels_latents=num_channels,
532
  height=height,
533
  width=width,
534
  dtype=torch.float32,
535
  device=device,
536
+ generator=generator,
537
  )
538
+
539
  LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
540
  return latents
541
  except Exception as e:
542
  LOGS.append(f"⚠️ Latent extraction failed: {e}")
543
+ # fallback: guess a safe shape
544
+ fallback_channels = 16 # try standard default for ZImage pipelines
545
+ latents = torch.randn((1, fallback_channels, height // 8, width // 8),
546
+ generator=generator, device=device)
547
+ LOGS.append(f"🔹 Using fallback random latents shape: {latents.shape}")
548
+ return latents
549
 
550
  # --------------------------
551
  # Main generation function (kept exactly as your logic)