import torch
import spaces
import gradio as gr
import sys
import platform
import os
import diffusers
import transformers
import peft
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
# ============================================================
# LOGGING BUFFER
# ============================================================
LOGS = ""
def log(msg):
    """Print a message and append it to the global log buffer shown in the UI."""
    global LOGS
    print(msg)
    LOGS += msg + "\n"
    return msg
# ============================================================
# ENVIRONMENT INFO
# ============================================================
log("===================================================")
log("πŸ” Z-IMAGE-TURBO DEBUGGING + DETAILED TRANSFORMER INSIGHTS")
log("===================================================\n")
log(f"πŸ“Œ PYTHON VERSION : {sys.version.replace(chr(10), ' ')}")
log(f"πŸ“Œ PLATFORM : {platform.platform()}")
log(f"πŸ“Œ TORCH VERSION : {torch.__version__}")
log(f"πŸ“Œ TRANSFORMERS VERSION : {transformers.__version__}")
log(f"πŸ“Œ DIFFUSERS VERSION : {diffusers.__version__}")
log(f"πŸ“Œ PEFT VERSION : {peft.__version__}")
log(f"πŸ“Œ CUDA AVAILABLE : {torch.cuda.is_available()}")
if torch.cuda.is_available():
log(f"πŸ“Œ GPU NAME : {torch.cuda.get_device_name(0)}")
log(f"πŸ“Œ GPU CAPABILITY : {torch.cuda.get_device_capability(0)}")
log(f"πŸ“Œ GPU MEMORY (TOTAL) : {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
log(f"πŸ“Œ FLASH ATTENTION : {torch.backends.cuda.flash_sdp_enabled()}")
else:
raise RuntimeError("❌ CUDA is REQUIRED but not available.")
device = "cuda"
gpu_id = 0
# ============================================================
# MODEL SETTINGS
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
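# Note: bfloat16 assumes an Ampere-or-newer GPU (compute capability >= 8.0);
# on older cards, torch.float16 is the usual fallback for both the bnb
# compute dtype and the pipeline dtype.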
USE_CPU_OFFLOAD = False
log("\n===================================================")
log("🧠 MODEL CONFIGURATION")
log("===================================================")
log(f"Model ID : {model_id}")
log(f"Model Cache Directory : {model_cache}")
log(f"torch_dtype : {torch_dtype}")
log(f"USE_CPU_OFFLOAD : {USE_CPU_OFFLOAD}")
# ============================================================
# LOAD TRANSFORMER BLOCK
# ============================================================
log("\n===================================================")
log("πŸ”§ LOADING TRANSFORMER BLOCK")
log("===================================================")
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    # Despite the "int8" name, this option also applies to 4-bit loading:
    # the listed modules are skipped during quantization and kept in bf16.
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
log("4-bit Quantization Config (Transformer):")
log(str(quantization_config))
transformer = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
log("βœ… Transformer block loaded successfully.")
# ------------------------------------------------------------
# TRANSFORMER INSIGHTS
# ------------------------------------------------------------
log("πŸ” Transformer Architecture Details:")
log(f"Number of Transformer Modules : {len(transformer.transformer_blocks)}")
for i, block in enumerate(transformer.transformer_blocks):
log(f" Block {i}: {block.__class__.__name__}")
# Log attention type if possible
attn_type = getattr(block, "attn", None)
if attn_type:
log(f" Attention: {attn_type.__class__.__name__}")
# Check for FlashAttention usage if attribute exists
flash_enabled = getattr(attn_type, "flash", None)
log(f" FlashAttention Enabled? : {flash_enabled}")
log(f"Hidden size: {transformer.config.hidden_size}")
log(f"Number of attention heads: {transformer.config.num_attention_heads}")
log(f"Number of layers: {transformer.config.num_hidden_layers}")
log(f"Intermediate size: {transformer.config.intermediate_size}")
if USE_CPU_OFFLOAD:
    transformer = transformer.to("cpu")
# ============================================================
# LOAD TEXT ENCODER
# ============================================================
log("\n===================================================")
log("πŸ”§ LOADING TEXT ENCODER")
log("===================================================")
quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
log("4-bit Quantization Config (Text Encoder):")
log(str(quantization_config))
text_encoder = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="text_encoder",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
log("βœ… Text encoder loaded successfully.")
# ------------------------------------------------------------
# TEXT ENCODER INSIGHTS
# ------------------------------------------------------------
log("πŸ” Text Encoder Architecture Details:")
log(f"Number of Transformer Modules : {len(text_encoder.transformer_blocks)}")
for i, block in enumerate(text_encoder.transformer_blocks):
log(f" Block {i}: {block.__class__.__name__}")
attn_type = getattr(block, "attn", None)
if attn_type:
log(f" Attention: {attn_type.__class__.__name__}")
flash_enabled = getattr(attn_type, "flash", None)
log(f" FlashAttention Enabled? : {flash_enabled}")
log(f"Hidden size: {text_encoder.config.hidden_size}")
log(f"Number of attention heads: {text_encoder.config.num_attention_heads}")
log(f"Number of layers: {text_encoder.config.num_hidden_layers}")
log(f"Intermediate size: {text_encoder.config.intermediate_size}")
if USE_CPU_OFFLOAD:
    text_encoder = text_encoder.to("cpu")
# ============================================================
# BUILD PIPELINE
# ============================================================
log("\n===================================================")
log("πŸ”§ BUILDING Z-IMAGE-TURBO PIPELINE")
log("===================================================")
pipe = ZImagePipeline.from_pretrained(
    model_id,
    cache_dir=model_cache,  # keep the remaining components in the same cache
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)
if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    log("⚙ CPU OFFLOAD ENABLED")
else:
    pipe.to(device)
    log("⚙ Pipeline moved to GPU")
log("✅ Pipeline ready.")
# ============================================================
# INFERENCE FUNCTION
# ============================================================
@spaces.GPU
def generate_image(prompt, height, width, steps, seed):
    global LOGS
    LOGS = ""  # Reset the log buffer for this run
    log("===================================================")
    log("🎨 RUNNING INFERENCE")
    log("===================================================")
    log(f"Prompt     : {prompt}")
    log(f"Resolution : {width} x {height}")
    log(f"Steps      : {steps}")
    log(f"Seed       : {seed}")
    # Gradio sliders can deliver floats; cast before handing off to the pipe.
    height, width, steps, seed = int(height), int(width), int(steps), int(seed)
    generator = torch.Generator(device).manual_seed(seed)
    out = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        # Turbo (distilled) checkpoints run without classifier-free guidance.
        guidance_scale=0.0,
        generator=generator,
    )
    log("✅ Inference Finished")
    return out.images[0], LOGS
# ============================================================
# GRADIO UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **🚀 Z-Image-Turbo – Transformer Deep Logs**")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic portrait of a middle-aged man")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output Image")
            logs_panel = gr.Textbox(label="📜 Transformer Logs", lines=25, interactive=False)
    btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image, logs_panel],
    )
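# On Spaces, enabling the request queue can help under concurrent load,
# e.g. demo.queue(max_size=20).launch() instead of the bare launch() below.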
demo.launch()