Spaces:
Paused
Paused
Update app_quant_latent.py
Browse files- app_quant_latent.py +21 -4
app_quant_latent.py
CHANGED
|
@@ -513,22 +513,39 @@ import torch
|
|
| 513 |
# Helper: Safe latent extractor
|
| 514 |
# --------------------------
|
| 515 |
def safe_get_latents(pipe, height, width, generator, device, LOGS):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
latents = pipe.prepare_latents(
|
| 518 |
batch_size=1,
|
| 519 |
-
num_channels
|
| 520 |
height=height,
|
| 521 |
width=width,
|
| 522 |
dtype=torch.float32,
|
| 523 |
device=device,
|
| 524 |
-
generator=generator
|
| 525 |
)
|
|
|
|
| 526 |
LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
|
| 527 |
return latents
|
| 528 |
except Exception as e:
|
| 529 |
LOGS.append(f"⚠️ Latent extraction failed: {e}")
|
| 530 |
-
|
| 531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
# --------------------------
|
| 534 |
# Main generation function (kept exactly as your logic)
|
|
|
|
| 513 |
# Helper: Safe latent extractor
|
| 514 |
# --------------------------
|
| 515 |
def safe_get_latents(pipe, height, width, generator, device, LOGS):
|
| 516 |
+
"""
|
| 517 |
+
Safely prepare latents for any ZImagePipeline variant.
|
| 518 |
+
Returns latents tensor, logs issues instead of failing.
|
| 519 |
+
"""
|
| 520 |
try:
|
| 521 |
+
# Determine number of channels
|
| 522 |
+
num_channels = 4 # default fallback
|
| 523 |
+
if hasattr(pipe, "unet") and hasattr(pipe.unet, "in_channels"):
|
| 524 |
+
num_channels = pipe.unet.in_channels
|
| 525 |
+
elif hasattr(pipe, "vae") and hasattr(pipe.vae, "latent_channels"):
|
| 526 |
+
num_channels = pipe.vae.latent_channels # some pipelines define this
|
| 527 |
+
LOGS.append(f"🔹 Using num_channels={num_channels} for latents")
|
| 528 |
+
|
| 529 |
latents = pipe.prepare_latents(
|
| 530 |
batch_size=1,
|
| 531 |
+
num_channels_latents=num_channels,
|
| 532 |
height=height,
|
| 533 |
width=width,
|
| 534 |
dtype=torch.float32,
|
| 535 |
device=device,
|
| 536 |
+
generator=generator,
|
| 537 |
)
|
| 538 |
+
|
| 539 |
LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
|
| 540 |
return latents
|
| 541 |
except Exception as e:
|
| 542 |
LOGS.append(f"⚠️ Latent extraction failed: {e}")
|
| 543 |
+
# fallback: guess a safe shape
|
| 544 |
+
fallback_channels = 16 # try standard default for ZImage pipelines
|
| 545 |
+
latents = torch.randn((1, fallback_channels, height // 8, width // 8),
|
| 546 |
+
generator=generator, device=device)
|
| 547 |
+
LOGS.append(f"🔹 Using fallback random latents shape: {latents.shape}")
|
| 548 |
+
return latents
|
| 549 |
|
| 550 |
# --------------------------
|
| 551 |
# Main generation function (kept exactly as your logic)
|