from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from diffsynth.core.attention import attention_forward
from diffsynth.core.gradient import gradient_checkpoint_forward
from diffsynth.models.flux2_dit import apply_rotary_emb, Flux2PosEmbed
from diffsynth.models.general_modules import get_timestep_embedding
from PIL import Image
import numpy as np


class AdaLayerNormContinuous(nn.Module):
    def __init__(self, dim_in, dim_out, eps=1e-6):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out * 2, bias=False)
        self.norm = nn.LayerNorm(dim_in, eps=eps, elementwise_affine=False, bias=False)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(torch.nn.functional.silu(conditioning_embedding)).chunk(2, dim=1)
        # Insert a token axis so the (B, D) modulation broadcasts over (B, S, D) inputs.
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x


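# A minimal shape sketch for the adaptive LayerNorm above; the sizes are
# illustrative assumptions, not values used elsewhere in this file:
#
#   ada = AdaLayerNormContinuous(dim_in=64, dim_out=64)
#   x = torch.randn(2, 16, 64)    # (batch, tokens, dim)
#   temb = torch.randn(2, 64)     # conditioning embedding, e.g. a timestep embedding
#   out = ada(x, temb)            # (2, 16, 64)
#
# scale and shift are per-channel and broadcast over the token axis; the
# (1 + scale) form keeps the layer close to identity while the projected
# scale stays near zero.

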
class Flux2FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # SwiGLU-style MLP: one fused input projection produces both branches.
        self.linear_in = nn.Linear(dim, dim * 3 * 2, bias=False)
        self.linear_out = nn.Linear(dim * 3, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.linear_in(x).chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        x = self.linear_out(x)
        return x


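# The feed-forward above is a gated-linear-unit MLP ("SwiGLU") with a 3x
# expansion. A minimal unfused equivalent, with dim=8 as an arbitrary example:
#
#   w_gate = nn.Linear(8, 24, bias=False)
#   w_value = nn.Linear(8, 24, bias=False)
#   w_out = nn.Linear(24, 8, bias=False)
#   y = w_out(torch.nn.functional.silu(w_gate(x)) * w_value(x))
#
# linear_in fuses w_gate and w_value into one (dim -> 6*dim) matmul, which is
# equivalent up to weight layout and saves a kernel launch.

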
class Flux2TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, eps=1e-6):
        super().__init__()
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_ff = Flux2FeedForward(dim)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_ff = Flux2FeedForward(dim)

        self.num_heads = num_heads
        self.img_to_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.img_norm_q = nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_norm_k = nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_to_out = nn.Linear(dim, dim, bias=False)
        self.txt_to_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.txt_norm_q = nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_norm_k = nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_to_out = nn.Linear(dim, dim, bias=False)

    def attention(self, img: torch.Tensor, txt: torch.Tensor, image_rotary_emb: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        img_q, img_k, img_v = self.img_to_qkv(img).chunk(3, dim=-1)
        txt_q, txt_k, txt_v = self.txt_to_qkv(txt).chunk(3, dim=-1)
        img_q, img_k, img_v, txt_q, txt_k, txt_v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), (img_q, img_k, img_v, txt_q, txt_k, txt_v)))
        img_q = self.img_norm_q(img_q)
        img_k = self.img_norm_k(img_k)
        txt_q = self.txt_norm_q(txt_q)
        txt_k = self.txt_norm_k(txt_k)

        # Joint attention: text tokens are placed before image tokens.
        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)
        q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
        k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        out = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")
        txt_out, img_out = out.split_with_sizes([txt.shape[1], img.shape[1]], dim=1)
        txt_out = self.txt_to_out(txt_out)
        img_out = self.img_to_out(img_out)
        return img_out, txt_out, (k, v)

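    # Sequence layout of the fused attention call above, with illustrative
    # token counts (not values fixed by this file):
    #
    #   txt: (B, 512, D), img: (B, 4096, D)
    #   q, k, v after concat: (B, 4608, heads, head_dim)   # text first, image second
    #   the output is split back via split_with_sizes([512, 4096], dim=1)
    #
    # The (k, v) pair covering the full joint sequence is returned so a later
    # pass can reuse this block's keys and values as a KV cache.
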
    def forward(self, img, txt, temb_mod_params_img, temb_mod_params_txt, image_rotary_emb):
        (img_shift_msa, img_scale_msa, img_gate_msa), (img_shift_mlp, img_scale_mlp, img_gate_mlp) = temb_mod_params_img
        (txt_shift_msa, txt_scale_msa, txt_gate_msa), (txt_shift_mlp, txt_scale_mlp, txt_gate_mlp) = temb_mod_params_txt

        norm_img = (1 + img_scale_msa) * self.img_norm1(img) + img_shift_msa
        norm_txt = (1 + txt_scale_msa) * self.txt_norm1(txt) + txt_shift_msa
        img_attn_out, txt_attn_out, kv_cache = self.attention(norm_img, norm_txt, image_rotary_emb)

        img = img + img_gate_msa * img_attn_out
        norm_img = self.img_norm2(img) * (1 + img_scale_mlp) + img_shift_mlp
        img = img + img_gate_mlp * self.img_ff(norm_img)

        txt = txt + txt_gate_msa * txt_attn_out
        norm_txt = self.txt_norm2(txt) * (1 + txt_scale_mlp) + txt_shift_mlp
        txt = txt + txt_gate_mlp * self.txt_ff(norm_txt)
        return txt, img, kv_cache


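# The dual-stream block applies the standard DiT recipe once per stream:
# modulated norm -> joint attention -> gated residual, then modulated norm ->
# feed-forward -> gated residual. Compressed to one stream:
#
#   h = (1 + scale_msa) * norm1(x) + shift_msa
#   x = x + gate_msa * attn(h)
#   h = (1 + scale_mlp) * norm2(x) + shift_mlp
#   x = x + gate_mlp * ff(h)
#
# Only the attention is computed jointly; norms, feed-forwards, and modulation
# parameters stay separate for the image and text streams.

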
class Flux2SingleTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.dim = dim
        self.num_heads = num_heads
        self.norm_q = nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
        self.norm_k = nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
        # One fused projection produces QKV (3 * dim) and both MLP branches (2 * 3 * dim).
        self.to_qkv_mlp_proj = nn.Linear(dim, dim * 3 + dim * 3 * 2, bias=False)
        # Attention output (dim) and gated MLP output (3 * dim) are projected back jointly.
        self.to_out = nn.Linear(dim + dim * 3, dim, bias=False)

    def attention(self, x: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        x = self.to_qkv_mlp_proj(x)
        qkv, mlp_x = torch.split(x, [3 * self.dim, self.dim * 3 * 2], dim=-1)
        q, k, v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), qkv.chunk(3, dim=-1)))

        q = self.norm_q(q)
        k = self.norm_k(k)
        q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
        k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)
        x = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")

        x1, x2 = mlp_x.chunk(2, dim=-1)
        x = torch.cat([x, torch.nn.functional.silu(x1) * x2], dim=-1)
        x = self.to_out(x)
        return x, (k, v)

    def forward(self, x, temb_mod_params, image_rotary_emb):
        mod_shift, mod_scale, mod_gate = temb_mod_params
        norm_x = (1 + mod_scale) * self.norm(x) + mod_shift
        attn_output, kv_cache = self.attention(x=norm_x, image_rotary_emb=image_rotary_emb)
        x = x + mod_gate * attn_output
        return x, kv_cache


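# Unlike the dual-stream block, the single-stream block runs attention and the
# SwiGLU MLP in parallel from one fused input projection and merges them with a
# single output matmul, mirroring the parallel layout of Flux's original single
# blocks. Bookkeeping for inner dimension D:
#
#   to_qkv_mlp_proj: D -> 3*D (q, k, v) + 6*D (two MLP branches of width 3*D)
#   to_out:          D (attention out) + 3*D (gated MLP out) -> D

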
class Flux2TimestepGuidanceEmbeddings(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.timestep_embedder = nn.Sequential(nn.Linear(dim_in, dim_out, bias=False), nn.SiLU(), nn.Linear(dim_out, dim_out, bias=False))

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        timesteps_proj = get_timestep_embedding(timestep, self.dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))
        return timesteps_emb


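# get_timestep_embedding produces the usual sinusoidal features that the MLP
# then projects. A minimal reference implementation (the diffsynth helper
# additionally handles flip_sin_to_cos and downscale_freq_shift):
#
#   import math
#   def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
#       half = dim // 2
#       freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
#       args = t.float()[:, None] * freqs[None, :]
#       return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

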
class Flux2Modulation(nn.Module):
    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
        super().__init__()
        self.mod_param_sets = mod_param_sets
        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)

    def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        mod = torch.nn.functional.silu(temb)
        mod = self.linear(mod)
        # Add a token axis so the parameters broadcast over the sequence dimension.
        mod = mod.unsqueeze(1)
        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
        # Group the chunks into (shift, scale, gate) triples, one per set.
        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))


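# Example of the modulation output structure, with dim=8 as an illustrative size:
#
#   mod = Flux2Modulation(8, mod_param_sets=2)
#   temb = torch.randn(4, 8)
#   (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = mod(temb)
#   # each tensor: (4, 1, 8), ready to broadcast against (batch, tokens, 8)
#
# mod_param_sets=2 serves the dual-stream block (attention + MLP modulation);
# mod_param_sets=1 serves the single-stream block, which modulates only once.

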
class Flux2DiTVariantModel(nn.Module):
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 128,
        out_channels: Optional[int] = None,
        num_layers: int = 5,
        num_single_layers: int = 20,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 7680,
        timestep_guidance_channels: int = 256,
        axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
        rope_theta: int = 2000,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # Rotary position embedding over four axes (latent id, height, width, token index).
        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)

        # Timestep embedding.
        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
            dim_in=timestep_guidance_channels,
            dim_out=self.inner_dim,
        )

        # Modulation heads: two (shift, scale, gate) sets per dual-stream block,
        # one set per single-stream block.
        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)

        # Input projections for image latents and text embeddings.
        self.img_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
        self.txt_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

        # Dual-stream blocks.
        self.transformer_blocks = nn.ModuleList([Flux2TransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_layers)])

        # Single-stream blocks.
        self.single_transformer_blocks = nn.ModuleList([Flux2SingleTransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_single_layers)])

        # Output head (not used by the forward pass below, which only builds a KV cache).
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

    def prepare_static_parameters(self, img, txt):
        timestep = torch.zeros((1,), dtype=txt.dtype, device=txt.device)
        img_ids = []
        for latent_id, latent in enumerate(img):
            _, _, height, width = latent.shape
            # Axes: (latent id, height, width, token index); each reference latent gets its own id band.
            x_ids = torch.cartesian_prod(torch.tensor([(latent_id + 1) * 10]), torch.arange(height), torch.arange(width), torch.arange(1))
            img_ids.append(x_ids)
        img_ids = torch.cat(img_ids, dim=0).to(txt.device)
        txt_ids = torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(txt.shape[1])).to(txt.device)
        return timestep, img_ids, txt_ids

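    # Position-id layout sketch: for two reference latents of 2x2 and 1x2
    # patches plus a 3-token prompt, the method yields
    #
    #   img_ids: [[10,0,0,0], [10,0,1,0], [10,1,0,0], [10,1,1,0],
    #             [20,0,0,0], [20,0,1,0]]
    #   txt_ids: [[0,0,0,0], [0,0,0,1], [0,0,0,2]]
    #
    # so image tokens are distinguished by (latent_id * 10, y, x) and text
    # tokens only by their index along the fourth axis.
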
    def patchify(self, img):
        img_ = []
        for latent in img:
            # (B, C, H, W) -> (B, H*W, C), then concatenate all latents along the token axis.
            latent = rearrange(latent, "B C H W -> B (H W) C")
            img_.append(latent)
        img_ = torch.cat(img_, dim=1)
        return img_

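    # Example with the default in_channels=128 and patch_size=1: latents of
    # shape (1, 128, 32, 32) and (1, 128, 16, 16) flatten to (1, 1024, 128)
    # and (1, 256, 128), concatenating into a single (1, 1280, 128) sequence.
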
    def process_image(self, image, mask):
        mask = mask.convert("RGB").resize(image.size)
        mask = np.array(mask).mean(axis=-1)
        image = np.array(image)
        # Zero out the masked region of the image.
        image[mask > 127] = 0
        # mask is float64 after mean(); cast back to uint8 before building a PIL image.
        return Image.fromarray(image), Image.fromarray(mask.astype(np.uint8)).convert("RGB")

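    # process_image expects PIL inputs: `image` is the RGB source, `mask` marks
    # the inpainting region with bright pixels (> 127 after grayscale). It
    # returns (masked_image, mask) as two RGB images, which become the two
    # VAE-encoded conditioning latents below.
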
    @torch.no_grad()
    def process_inputs(
        self,
        pipe,
        image,
        mask,
        prompt="Complete the content in the annotated region of the image.",
        force_inpaint=False,
        **kwargs
    ):
        masked_image, mask = self.process_image(image, mask)
        images = [masked_image, mask]
        pipe.load_models_to_device(["vae"])
        kv_cache_input_latents = [pipe.vae.encode(pipe.preprocess_image(img)) for img in images]
        prompt_emb_unit = [unit for unit in pipe.units if unit.__class__.__name__ == "Flux2Unit_Qwen3PromptEmbedder"][0]
        kv_cache_prompt_emb = prompt_emb_unit.process(pipe, prompt)["prompt_embeds"]
        pipe.load_models_to_device([])
        return {
            "kv_cache_input_latents": kv_cache_input_latents,
            "kv_cache_prompt_emb": kv_cache_prompt_emb,
            "image": image,
            "mask": mask,
            "force_inpaint": force_inpaint,
        }

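    # The dict returned above is passed into forward() below by keyword
    # expansion: the two VAE latents form the image token sequence and the
    # prompt embedding forms the text sequence, so the resulting KV cache can
    # be reused by a later sampling pass.
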
    def forward(
        self,
        kv_cache_input_latents,
        kv_cache_prompt_emb,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
        image=None,
        mask=None,
        force_inpaint=False,
        **kwargs,
    ):
        img = kv_cache_input_latents
        txt = kv_cache_prompt_emb

        timestep, img_ids, txt_ids = self.prepare_static_parameters(img, txt)
        img = self.patchify(img)

        temb = self.time_guidance_embed(timestep)
        double_stream_mod_img = self.double_stream_modulation_img(temb)
        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
        single_stream_mod = self.single_stream_modulation(temb)[0]

        img = self.img_embedder(img)
        txt = self.txt_embedder(txt)

        # Rotary embeddings follow the same text-first ordering as the attention concat.
        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        concat_rotary_emb = (
            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
        )

        # Dual-stream blocks: collect each block's (k, v) over the joint sequence.
        kv_cache = {}
        for block_id, block in enumerate(self.transformer_blocks):
            txt, img, kv_cache_ = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                img=img,
                txt=txt,
                temb_mod_params_img=double_stream_mod_img,
                temb_mod_params_txt=double_stream_mod_txt,
                image_rotary_emb=concat_rotary_emb,
            )
            kv_cache[f"double_{block_id}"] = kv_cache_

        # Single-stream blocks operate on the concatenated (text, image) sequence.
        img = torch.cat([txt, img], dim=1)
        for block_id, block in enumerate(self.single_transformer_blocks):
            img, kv_cache_ = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=img,
                temb_mod_params=single_stream_mod,
                image_rotary_emb=concat_rotary_emb,
            )
            kv_cache[f"single_{block_id}"] = kv_cache_

        # No final norm/projection here: this variant only produces the KV cache.
        results = {"kv_cache": kv_cache}
        if force_inpaint:
            results.update({
                "input_image": image,
                "inpaint_mask": mask,
                "inpaint_blur_size": 1,
                "inpaint_blur_sigma": 1,
            })
        return results


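# Shape summary of the cache built above: kv_cache maps "double_{i}" and
# "single_{i}" to (k, v) pairs of shape (B, txt_len + img_len, heads, head_dim),
# covering the prompt tokens plus both conditioning latents. A consumer can
# concatenate these cached keys/values in front of its own per-layer k and v
# before attention, the usual KV-cache pattern.

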
class TrainDataProcessor:
    def __init__(self):
        from diffsynth.core import UnifiedDataset
        self.image_operator = UnifiedDataset.default_image_operator(
            base_path="",
            max_pixels=1024*1024,
            height_division_factor=16,
            width_division_factor=16,
        )

    def generate_bbox(self, height, width):
        # Sample a random box with side lengths of at least 10 pixels;
        # assumes height and width are both greater than 20.
        h = torch.randint(10, height - 10, (1,)).item()
        w = torch.randint(10, width - 10, (1,)).item()
        x = torch.randint(0, height - h, (1,)).item()
        y = torch.randint(0, width - w, (1,)).item()
        return x, x + h, y, y + w

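    # Note the (row_start, row_end, col_start, col_end) return convention:
    # x indexes rows (height axis) and y indexes columns (width axis), matching
    # the image[x:x_, y:y_] slicing in generate_mask below.
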
    def generate_mask(self, image):
        image = np.array(image)
        height, width, _ = image.shape
        x, x_, y, y_ = self.generate_bbox(height, width)
        # Black out the box in the image; mark the same box white in the mask.
        image[x: x_, y: y_] = 0

        mask = np.zeros_like(image)
        mask[x: x_, y: y_] = 255
        return Image.fromarray(image), Image.fromarray(mask)

    def __call__(self, image, **kwargs):
        image = self.image_operator(image)
        image, mask = self.generate_mask(image)
        return {
            "image": image,
            "mask": mask,
        }


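# Training-time flow of this processor: the dataset operator resizes the image
# to the pixel budget, a random rectangle is blacked out, and the returned
# (image, mask) pair matches the keyword arguments that
# Flux2DiTVariantModel.process_inputs expects.

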
TEMPLATE_MODEL = Flux2DiTVariantModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = TrainDataProcessor