# coralLight: add 4k inference (commit 8e2dc9a)
"""SAMPLING ONLY."""
import torch
from dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3
from uni_pc import UniPC
from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
class DPMSolverv3Sampler:
    """Sampler driving DPM-Solver-v3 over a diffusers pipeline's UNet.

    Builds a discrete noise schedule from the pipeline scheduler's
    ``alphas_cumprod`` and samples with optional classifier-free guidance.
    FreeU backbone scaling can be toggled mid-sampling via the registered
    callbacks when ``start_free_u_step`` is provided to :meth:`sample`.
    """

    def __init__(self, stats_dir, pipe, steps, guidance_scale, **kwargs):
        """
        Args:
            stats_dir: Directory holding the pre-computed statistics required
                by DPM-Solver-v3. Must not be None.
            pipe: A diffusers pipeline exposing ``unet`` and ``scheduler``.
            steps: Number of solver steps.
            guidance_scale: Classifier-free guidance scale.
        """
        super().__init__()
        self.model = pipe
        # float32 schedule; float16 also works but pictures seem a bit worse.
        DTYPE = torch.float32
        device = "cuda"
        noise_scheduler = pipe.scheduler
        self.alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.device = device
        self.guidance_scale = guidance_scale
        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
        assert stats_dir is not None, f"No statistics file found in {stats_dir}."
        print("Use statistics", stats_dir)
        self.dpm_solver_v3 = DPM_Solver_v3(
            statistics_dir=stats_dir,
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type="customed_time_karras",
            degenerated=False,
            device=self.device,
        )
        self.steps = steps

    @torch.no_grad()
    def apply_free_unet(self):
        """Enable FreeU scaling on the pipeline's UNet up-blocks."""
        register_free_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)

    @torch.no_grad()
    def stop_free_unet(self):
        """Disable FreeU by restoring neutral (identity) scale factors."""
        register_free_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Draw latent samples with DPM-Solver-v3.

        Args:
            batch_size: Number of latents to sample.
            shape: Per-sample latent shape ``(C, H, W)``.
            conditioning: Conditioning tensor (or dict of tensors); ``None``
                selects unconditional sampling.
            x_T: Optional starting noise; i.i.d. Gaussian when ``None``.
            unconditional_conditioning: Negative conditioning for CFG.
            use_corrector: Forwarded to the solver.
            half: Forwarded to the solver.
            start_free_u_step: When set, FreeU is toggled by the solver via
                the apply/stop callbacks.

        Returns:
            Tuple ``(latents, None)`` with latents moved to ``self.device``.
        """
        if conditioning is not None:
            # Only a sanity check — a mismatch is reported but not fatal.
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        img = torch.randn(size, device=self.device) if x_T is None else x_T
        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=self.guidance_scale,
            )
            # Guided sampling uses order 2 only at exactly 8 steps.
            ORDER = 2 if self.steps == 8 else 1
        x = self.dpm_solver_v3.sample(
            img,
            model_fn,
            order=ORDER,
            p_pseudo=False,
            c_pseudo=True,
            lower_order_final=True,
            use_corrector=use_corrector,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )
        return x.to(self.device), None
class UniPCSampler:
    """UniPC solver wrapper around a diffusers pipeline.

    Supports classifier-free guidance, SDXL-style conditioning via an
    ``xl_preprocess_closure``, an optional noise-prior network (``npnet``),
    AFS for few-step high-resolution sampling, and FreeU toggling through
    solver callbacks.
    """

    def __init__(
        self,
        pipe,
        model_closure,
        steps,
        guidance_scale,
        denoise_to_zero=False,
        need_fp16_discrete_method=False,
        ultilize_vae_in_fp16=False,
        is_high_resoulution=True,
        skip_type="customed_time_karras",
        force_not_use_afs=False,
        **kwargs,
    ):
        """
        Args:
            pipe: diffusers pipeline; supplies scheduler, device and dtype.
            model_closure: Callable taking ``pipe`` and returning the noise
                prediction function used by the solver.
            steps: Number of solver steps.
            guidance_scale: Classifier-free guidance scale.
            denoise_to_zero: Forwarded to UniPC.
            need_fp16_discrete_method: Forwarded to UniPC; also lowers the
                solver order in :meth:`sample_mix`.
            ultilize_vae_in_fp16: Forwarded to UniPC (name kept as-is to
                match the UniPC keyword).
            is_high_resoulution: Forwarded to UniPC (name kept as-is);
                also gates AFS.
            skip_type: Timestep skip schedule name.
            force_not_use_afs: Disable AFS even when it would otherwise apply.
        """
        super().__init__()
        self.model = model_closure(pipe)
        self.pipe = pipe
        self.need_fp16_discrete_method = need_fp16_discrete_method
        # Schedule follows the UNet's dtype (fp16 works; quality a bit worse).
        DTYPE = self.pipe.unet.dtype
        device = self.pipe.device
        self.alphas_cumprod = pipe.scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.device = device
        self.guidance_scale = guidance_scale
        # AFS is only enabled for few-step, high-resolution sampling.
        self.use_afs = steps <= 8 and is_high_resoulution and not force_not_use_afs
        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
        self.unipc_solver = UniPC(
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type=skip_type,
            degenerated=False,
            use_afs=self.use_afs,
            device=self.device,
            denoise_to_zero=denoise_to_zero,
            need_fp16_discrete_method=self.need_fp16_discrete_method,
            ultilize_vae_in_fp16=ultilize_vae_in_fp16,
            is_high_resoulution=is_high_resoulution,
        )
        self.steps = steps

    @torch.no_grad()
    def apply_free_unet(self):
        """Enable FreeU scaling on the pipeline's UNet up-blocks."""
        register_free_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)

    @torch.no_grad()
    def stop_free_unet(self):
        """Disable FreeU by restoring neutral (identity) scale factors."""
        register_free_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    def _wrapped_model(self, conditioning, condition, unconditional_condition):
        """Build the solver-facing noise-prediction function.

        Branches on the raw ``conditioning`` (uncond vs CFG); the resolved
        ``condition``/``unconditional_condition`` (raw or SDXL-preprocessed)
        feed the classifier-free wrapper.
        """
        if conditioning is None:
            return model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
        return model_wrapper(
            lambda x, t, c: self.model(x, t, c),
            self.ns,
            model_type="noise",
            guidance_type="classifier-free",
            condition=condition,
            unconditional_condition=unconditional_condition,
            guidance_scale=self.guidance_scale,
        )

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Sample latents with UniPC.

        Args:
            batch_size: Number of latents to sample.
            shape: Per-sample latent shape ``(C, H, W)``.
            conditioning: Prompts/embeddings; ``None`` for unconditional.
            x_T: Optional starting noise; i.i.d. Gaussian when ``None``.
            unconditional_conditioning: Negative conditioning for CFG.
            use_corrector: Forwarded to the solver.
            half: Forwarded to the solver.
            start_free_u_step: When set, FreeU is toggled via callbacks.
            xl_preprocess_closure: Optional SDXL preprocessing; assumed to
                return ``(prompt_embeds, cond_kwargs)`` which replace the raw
                conditionings in the CFG wrapper — TODO confirm against caller.
            npnet: Optional noise-prior network; its output is forwarded
                separately as ``npnet_x`` while the solver still starts
                from ``img``.

        Returns:
            Tuple ``(latents, full_cache)``.
        """
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        use_xl = xl_preprocess_closure is not None
        prompt_embeds = cond_kwargs = None
        if use_xl:
            prompt_embeds, cond_kwargs = xl_preprocess_closure(
                pipe=self.pipe,
                prompts=conditioning,
                need_cfg=True,
                device=self.device,
                negative_prompts=unconditional_conditioning,
            )
        img = torch.randn(size, device=self.device) if x_T is None else x_T
        new_img = None
        if use_xl and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add dummy dimension for npnet
            new_img = npnet(img, c)
        model_fn = self._wrapped_model(
            conditioning,
            conditioning if not use_xl else prompt_embeds,
            unconditional_conditioning if not use_xl else cond_kwargs,
        )
        if conditioning is None:
            ORDER = 3
        else:
            # Guided sampling: order 2 needs at least 7 steps.
            ORDER = 2 if self.steps >= 7 else 1
        x, full_cache = self.unipc_solver.sample(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            npnet_x=new_img,
            npnet_scale=self.guidance_scale if new_img is not None else None,
            half=half,
        )
        return x.to(self.device), full_cache

    @torch.no_grad()
    def sample_mix(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Mixed variant of :meth:`sample` using ``UniPC.sample_mix``.

        Differences from :meth:`sample`: the npnet output (if any) replaces
        the starting noise directly instead of being passed as ``npnet_x``,
        and the solver order additionally drops to 1 when
        ``need_fp16_discrete_method`` is set.

        Returns:
            Tuple ``(latents, full_cache)``.
        """
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        use_xl = xl_preprocess_closure is not None
        prompt_embeds = cond_kwargs = None
        if use_xl:
            prompt_embeds, cond_kwargs = xl_preprocess_closure(
                pipe=self.pipe,
                prompts=conditioning,
                need_cfg=True,
                device=self.device,
                negative_prompts=unconditional_conditioning,
            )
        img = torch.randn(size, device=self.device) if x_T is None else x_T
        if use_xl and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add dummy dimension for npnet
            img = npnet(img, c)
        model_fn = self._wrapped_model(
            conditioning,
            conditioning if not use_xl else prompt_embeds,
            unconditional_conditioning if not use_xl else cond_kwargs,
        )
        if conditioning is None:
            ORDER = 3
        else:
            # fp16 discrete method is only stable at order 1.
            ORDER = 2 if self.steps >= 8 and not self.need_fp16_discrete_method else 1
        x, full_cache = self.unipc_solver.sample_mix(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )
        return x.to(self.device), full_cache