rahul7star commited on
Commit
f7c01ff
Β·
verified Β·
1 Parent(s): dca5b2f

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +48 -20
app_quant_latent.py CHANGED
@@ -554,24 +554,47 @@ def safe_get_latents(pipe, height, width, generator, device, LOGS):
554
  # --------------------------
555
  # Main generation function (kept exactly as your logic)
556
  # --------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  @spaces.GPU
558
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
559
  LOGS = []
560
  device = "cuda"
561
  generator = torch.Generator(device).manual_seed(int(seed))
562
 
563
- # placeholders
564
  placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
565
  latent_gallery = []
566
  final_gallery = []
567
 
568
  try:
569
- # --- Try advanced latent mode ---
570
  try:
571
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
572
 
 
573
  for i, t in enumerate(pipe.scheduler.timesteps):
574
- # Step-wise denoising
575
  with torch.no_grad():
576
  noise_pred = pipe.unet(
577
  latents,
@@ -580,28 +603,34 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
580
  )["sample"]
581
  latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]
582
 
583
- # Convert latent to preview image
584
- try:
585
- latent_img = latent_to_image(latents, pipe.vae)[0]
586
- except Exception:
587
- latent_img = placeholder
588
-
589
- latent_gallery.append(latent_img)
590
-
591
- # Yield intermediate update: latents updated, final gallery empty
592
- yield None, latent_gallery, final_gallery, LOGS
593
 
594
  # Decode final image
595
  final_img = pipe.decode_latents(latents)[0]
596
  final_gallery.append(final_img)
597
  LOGS.append("βœ… Advanced latent pipeline succeeded.")
598
- yield final_img, latent_gallery, final_gallery, LOGS
 
 
 
 
 
 
 
 
 
599
 
600
  except Exception as e:
601
  LOGS.append(f"⚠️ Advanced latent mode failed: {e}")
602
  LOGS.append("πŸ” Switching to standard pipeline...")
603
 
604
- # Standard pipeline fallback
605
  try:
606
  output = pipe(
607
  prompt=prompt,
@@ -613,22 +642,21 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
613
  )
614
  final_img = output.images[0]
615
  final_gallery.append(final_img)
616
- latent_gallery.append(final_img) # optionally show in latent gallery as last step
617
  LOGS.append("βœ… Standard pipeline succeeded.")
618
- yield final_img, latent_gallery, final_gallery, LOGS
619
 
620
  except Exception as e2:
621
  LOGS.append(f"❌ Standard pipeline failed: {e2}")
622
  final_gallery.append(placeholder)
623
  latent_gallery.append(placeholder)
624
- yield placeholder, latent_gallery, final_gallery, LOGS
625
 
626
  except Exception as e:
627
  LOGS.append(f"❌ Total failure: {e}")
628
  final_gallery.append(placeholder)
629
  latent_gallery.append(placeholder)
630
- yield placeholder, latent_gallery, final_gallery, LOGS
631
-
632
 
633
 
634
 
 
554
  # --------------------------
555
  # Main generation function (kept exactly as your logic)
556
  # --------------------------
557
+ from huggingface_hub import HfApi, HfFolder
558
+ import torch
559
+ import os
560
+
561
+ HF_REPO_ID = "rahul7star/Zstudio-latent" # Model repo
562
+ HF_TOKEN = HfFolder.get_token() # Make sure you are logged in via `huggingface-cli login`
563
+
564
+ def upload_latents_to_hf(latent_dict, filename="latents.pt"):
565
+ local_path = f"/tmp/{filename}"
566
+ torch.save(latent_dict, local_path)
567
+ try:
568
+ api = HfApi()
569
+ api.upload_file(
570
+ path_or_fileobj=local_path,
571
+ path_in_repo=filename,
572
+ repo_id=HF_REPO_ID,
573
+ token=HF_TOKEN,
574
+ repo_type="model" # since this is a model repo
575
+ )
576
+ os.remove(local_path)
577
+ return f"https://huggingface.co/{HF_REPO_ID}/resolve/main/{filename}"
578
+ except Exception as e:
579
+ os.remove(local_path)
580
+ raise e
581
+
582
  @spaces.GPU
583
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
584
  LOGS = []
585
  device = "cuda"
586
  generator = torch.Generator(device).manual_seed(int(seed))
587
 
 
588
  placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
589
  latent_gallery = []
590
  final_gallery = []
591
 
592
  try:
 
593
  try:
594
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
595
 
596
+ # Step-wise denoising
597
  for i, t in enumerate(pipe.scheduler.timesteps):
 
598
  with torch.no_grad():
599
  noise_pred = pipe.unet(
600
  latents,
 
603
  )["sample"]
604
  latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]
605
 
606
+ # Convert latent to preview image every few steps
607
+ if i % max(1, len(pipe.scheduler.timesteps)//10) == 0 or i == len(pipe.scheduler.timesteps)-1:
608
+ try:
609
+ latent_img = latent_to_image(latents, pipe.vae)[0]
610
+ except Exception:
611
+ latent_img = placeholder
612
+ latent_gallery.append(latent_img)
613
+ yield None, latent_gallery, LOGS # yield intermediate latents
 
 
614
 
615
  # Decode final image
616
  final_img = pipe.decode_latents(latents)[0]
617
  final_gallery.append(final_img)
618
  LOGS.append("βœ… Advanced latent pipeline succeeded.")
619
+
620
+ # Save latents to dict and upload to HF
621
+ latent_dict = {"latents": latents.cpu(), "prompt": prompt, "seed": seed}
622
+ try:
623
+ hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_{seed}.pt")
624
+ LOGS.append(f"πŸ”Ή Latents uploaded: {hf_url}")
625
+ except Exception as e:
626
+ LOGS.append(f"⚠️ Failed to upload latents: {e}")
627
+
628
+ yield final_img, latent_gallery, LOGS
629
 
630
  except Exception as e:
631
  LOGS.append(f"⚠️ Advanced latent mode failed: {e}")
632
  LOGS.append("πŸ” Switching to standard pipeline...")
633
 
 
634
  try:
635
  output = pipe(
636
  prompt=prompt,
 
642
  )
643
  final_img = output.images[0]
644
  final_gallery.append(final_img)
645
+ latent_gallery.append(final_img) # fallback latent preview
646
  LOGS.append("βœ… Standard pipeline succeeded.")
647
+ yield final_img, latent_gallery, LOGS
648
 
649
  except Exception as e2:
650
  LOGS.append(f"❌ Standard pipeline failed: {e2}")
651
  final_gallery.append(placeholder)
652
  latent_gallery.append(placeholder)
653
+ yield placeholder, latent_gallery, LOGS
654
 
655
  except Exception as e:
656
  LOGS.append(f"❌ Total failure: {e}")
657
  final_gallery.append(placeholder)
658
  latent_gallery.append(placeholder)
659
+ yield placeholder, latent_gallery, LOGS
 
660
 
661
 
662