File size: 1,043 Bytes
bb10560
 
 
b22b80e
bb10560
b05966a
bb10560
 
5ca41bd
afa2559
b05966a
 
 
bb10560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from aoti import aoti_load_

# --- Model Loading ---
# The pipeline is built once at import time; `generate_image` reuses this
# module-level `pipeline` for every request.

# Single source of truth for the weights' dtype (passed to from_pretrained).
dtype = torch.bfloat16
# Fall back to CPU when no CUDA device is visible.
# NOTE(review): the AOT artifact loaded below is presumably compiled for
# CUDA, so the CPU path may not actually work end to end — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev", torch_dtype=dtype
).to(device)
# Merge the attention Q/K/V projections into a single fused projection.
# Presumably this must happen before aoti_load_ so the module structure
# matches the one the compiled artifact was exported from — confirm.
pipeline.transformer.fuse_qkv_projections()
# Load the ahead-of-time compiled transformer ("flux-dev-aot.pt2") from the
# "sayakpaul/flux-dev-aot" repo into pipeline.transformer. The exact
# semantics are defined by the local `aoti` module (not visible here).
aoti_load_(pipeline.transformer, "sayakpaul/flux-dev-aot", "flux-dev-aot.pt2")


@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for *prompt* using the module-level Flux pipeline.

    Args:
        prompt: Text prompt forwarded unchanged to ``pipeline``.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars in the UI; it is not referenced
            directly in this body.

    Returns:
        A one-element list holding an ``(image, caption)`` tuple for
        ``gr.Gallery``. The caption is the elapsed generation time formatted
        as seconds with two decimals, e.g. ``"12.34s"``.
    """
    # Fixed seed so that repeated calls with the same prompt are reproducible.
    # Build the generator on the same device the pipeline was moved to,
    # rather than hard-coding "cuda". diffusers rejects a CUDA generator
    # when it has to produce CPU latents.
    generator = torch.Generator(device=device).manual_seed(42)
    # perf_counter is monotonic, so the measured duration is not affected by
    # wall-clock adjustments the way datetime.now() differences can be.
    t0 = time.perf_counter()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    elapsed = time.perf_counter() - t0
    return [(output.images[0], f'{elapsed:.2f}s')]


# Single-prompt UI: a text box in, a gallery out. The gallery receives the
# [(image, caption)] list returned by generate_image. Examples are not
# pre-computed (cache_examples=False), so nothing runs the model at startup.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Text(label="Prompt"),
    outputs=gr.Gallery(),
    examples=["A cat playing with a ball of yarn"],
    cache_examples=False,
)
# Start the Gradio server.
demo.launch()