rahul7star committed on
Commit
c9f04dd
·
verified ·
1 Parent(s): 2314c25

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +57 -8
app_quant_latent.py CHANGED
@@ -247,22 +247,76 @@ log_system_stats("AFTER PIPELINE BUILD")
247
 
248
 
249
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  @spaces.GPU
251
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
252
-
253
  LOGS = []
254
  latents = None
255
  image = None
256
  gallery = []
257
 
258
- # safe placeholder image
259
  placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
260
 
261
  try:
262
  generator = torch.Generator(device).manual_seed(int(seed))
263
 
264
  # -------------------------------
265
- # Try advanced latent extractor
266
  # -------------------------------
267
  try:
268
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
@@ -285,9 +339,6 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
285
  LOGS.append(f"⚠️ Latent mode failed: {e}")
286
  LOGS.append("🔁 Switching to standard pipeline...")
287
 
288
- # -------------------------------
289
- # Standard generation fallback
290
- # -------------------------------
291
  try:
292
  output = pipe(
293
  prompt=prompt,
@@ -306,13 +357,11 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
306
  image = placeholder
307
  gallery = [image]
308
 
309
- # final safe return
310
  return image, gallery, LOGS
311
 
312
  except Exception as e:
313
  LOGS.append(f"❌ Total failure: {e}")
314
  return placeholder, [placeholder], LOGS
315
-
316
  @spaces.GPU
317
  def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
318
  """
 
247
 
248
 
249
  from PIL import Image
250
+ import torch
251
+
252
+ # --------------------------
253
+ # Helper: Safe latent extractor
254
+ # --------------------------
255
def safe_get_latents(pipe, height, width, generator, device, LOGS):
    """
    Obtain an initial latent tensor, trying several strategies in order.

    Strategies (first success wins):
      1. The pipeline's official ``prepare_latents`` — only attempted when a
         ``unet`` with an ``in_channels`` attribute is exposed.
      2. A hidden ``_default_latents`` attribute some pipelines stash.
      3. A raw Gaussian tensor allocated on the requested device.
      4. A last-resort Gaussian tensor on the CPU with default settings.

    Args:
        pipe: diffusion pipeline object; may or may not expose a ``unet``.
        height: target image height in pixels.
        width: target image width in pixels.
        generator: seeded ``torch.Generator`` for reproducible noise.
        device: device the latents should live on (e.g. ``"cuda"``).
        LOGS: list of strings; progress/warning messages are appended to it.

    Returns:
        A latent tensor. This function never raises — every failure path
        falls through to the next strategy.
    """
    # 1) Official API. Guarded so pipelines that hide the UNet skip straight
    #    to the fallbacks instead of raising AttributeError.
    try:
        if hasattr(pipe, "unet") and hasattr(pipe.unet, "in_channels"):
            num_channels = pipe.unet.in_channels
            latents = pipe.prepare_latents(
                batch_size=1,
                num_channels=num_channels,
                height=height,
                width=width,
                dtype=torch.float32,
                device=device,
                generator=generator,
            )
            LOGS.append("✅ Latents extracted using official prepare_latents.")
            return latents
    except Exception as e:
        LOGS.append(f"⚠️ Official latent extraction failed: {e}")

    # 2) Hidden internal attribute on some pipeline implementations.
    try:
        if hasattr(pipe, "_default_latents"):
            LOGS.append("⚠️ Using hidden _default_latents.")
            return pipe._default_latents
    except Exception:
        # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
        # are no longer silently swallowed.
        pass

    # 3) Raw Gaussian fallback. Shape follows the standard SD latent layout:
    #    4 channels, spatial dims at 1/8 of the pixel resolution.
    try:
        LOGS.append("⚠️ Using raw Gaussian latents fallback.")
        return torch.randn(
            (1, 4, height // 8, width // 8),
            generator=generator,
            device=device,
            dtype=torch.float32,
        )
    except Exception as e:
        LOGS.append(f"⚠️ Gaussian fallback failed: {e}")

    # 4) Hard fallback: CPU tensor, default dtype and global RNG state —
    #    reached only if even device allocation above failed.
    LOGS.append("❗ Using CPU hard fallback latents.")
    return torch.randn((1, 4, height // 8, width // 8))
300
+
301
+
302
+ # --------------------------
303
+ # Main generation function
304
+ # --------------------------
305
  @spaces.GPU
306
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
 
307
  LOGS = []
308
  latents = None
309
  image = None
310
  gallery = []
311
 
312
+ # placeholder image if all fails
313
  placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
314
 
315
  try:
316
  generator = torch.Generator(device).manual_seed(int(seed))
317
 
318
  # -------------------------------
319
+ # Try advanced latent extraction
320
  # -------------------------------
321
  try:
322
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
 
339
  LOGS.append(f"⚠️ Latent mode failed: {e}")
340
  LOGS.append("🔁 Switching to standard pipeline...")
341
 
 
 
 
342
  try:
343
  output = pipe(
344
  prompt=prompt,
 
357
  image = placeholder
358
  gallery = [image]
359
 
 
360
  return image, gallery, LOGS
361
 
362
  except Exception as e:
363
  LOGS.append(f"❌ Total failure: {e}")
364
  return placeholder, [placeholder], LOGS
 
365
  @spaces.GPU
366
  def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
367
  """