Cosmos3OmniTransformer
A Mixture-of-Transformer (MoT) joint vision-language transformer introduced as part of NVIDIA’s Cosmos3 world foundation model family. The model runs two parallel computation pathways over a packed joint sequence:
- a causal “understanding” pathway that self-attends over text tokens with causal masking, and
- a bi-directional “generation” pathway that cross-attends from generation tokens (vision + optional sound latents) over the full understanding-plus-generation key/value set.
The two pathways share the same hidden size and number of layers but maintain separate Q/K/V/O projections, MLPs, and RMSNorm parameters, which is what makes the architecture a Mixture-of-Transformer rather than a standard Mixture-of-Experts. Position information is supplied through a 3D multimodal RoPE (mRoPE) that interleaves temporal / height / width frequencies for video latents and reuses the temporal axis for text and audio.
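The visibility rule for the two pathways can be sketched as a boolean mask. This is a minimal illustration in plain Python, assuming the understanding tokens occupy the first `und_len` positions of the packed sequence and the generation tokens follow. `build_visibility_mask` is a hypothetical helper, not part of Diffusers, and the actual model may realize the same rule differently (for example inside a fused attention kernel).

```python
def build_visibility_mask(und_len: int, sequence_length: int) -> list[list[bool]]:
    """Return mask[q][k], True when query position q may attend to key position k.

    Assumption: positions [0, und_len) are understanding (text) tokens and
    positions [und_len, sequence_length) are generation tokens.
    """
    if not 0 <= und_len <= sequence_length:
        raise ValueError("und_len must lie in [0, sequence_length]")

    mask = []
    for q in range(sequence_length):
        if q < und_len:
            # Understanding pathway: causal attention, so a text token only
            # sees itself and earlier text tokens.
            row = [k <= q for k in range(sequence_length)]
        else:
            # Generation pathway: bi-directional over the full
            # understanding-plus-generation key/value set.
            row = [True] * sequence_length
        mask.append(row)
    return mask


mask = build_visibility_mask(und_len=3, sequence_length=5)
for row in mask:
    print("".join("x" if visible else "." for visible in row))
# x....
# xx...
# xxx..
# xxxxx
# xxxxx
```

The first three rows form the causal triangle of the text prefix; the last two rows are fully visible because generation tokens attend to every position.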
The model can be loaded as follows.
import torch
from diffusers import Cosmos3OmniTransformer

transformer = Cosmos3OmniTransformer.from_pretrained(
    "nvidia/Cosmos3-Nano", subfolder="transformer", torch_dtype=torch.bfloat16
)
Cosmos3OmniTransformer
class diffusers.Cosmos3OmniTransformer(
    attention_bias: bool = False,
    attention_dropout: float = 0.0,
    dtype: str = 'bfloat16',
    head_dim: int = 128,
    hidden_size: int = 4096,
    intermediate_size: int = 12288,
    base_fps: int = 24,
    enable_fps_modulation: bool = True,
    latent_channel: int = 48,
    unified_3d_mrope_reset_spatial_ids: bool = True,
    unified_3d_mrope_temporal_modality_margin: int = 15000,
    latent_patch_size: int = 2,
    num_attention_heads: int = 32,
    num_hidden_layers: int = 36,
    num_key_value_heads: int = 8,
    patch_latent_dim: int = 192,
    rms_norm_eps: float = 1e-06,
    rope_scaling: dict | None = None,
    rope_theta: float = 5000000.0,
    sound_dim: int | None = None,
    sound_gen: bool = False,
    sound_latent_fps: float = 25.0,
    timestep_scale: float = 0.001,
    vocab_size: int = 151936,
)
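Several of the constructor defaults are arithmetically related, and a quick check makes the relationships visible. The sketch below copies a few defaults into a hypothetical `ConfigDefaults` dataclass (not a Diffusers class) and derives the attention widths. Two points are assumptions inferred from the numbers rather than stated in this page: that `num_key_value_heads < num_attention_heads` indicates grouped-query attention, and that `patch_latent_dim` equals `latent_channel * latent_patch_size ** 2` (spatial-only patching).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConfigDefaults:
    """Hypothetical container mirroring a few constructor defaults."""

    hidden_size: int = 4096
    head_dim: int = 128
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    latent_channel: int = 48
    latent_patch_size: int = 2
    patch_latent_dim: int = 192


cfg = ConfigDefaults()

# Total width of the query projection output; matches hidden_size here.
query_width = cfg.num_attention_heads * cfg.head_dim

# Total width of each key/value projection output.
kv_width = cfg.num_key_value_heads * cfg.head_dim

# Assumption: grouped-query attention, so this many query heads share one KV head.
queries_per_kv_head = cfg.num_attention_heads // cfg.num_key_value_heads

# Assumption: one patch flattens a patch_size x patch_size window of all channels.
values_per_patch = cfg.latent_channel * cfg.latent_patch_size ** 2

print(query_width, kv_width, queries_per_kv_head, values_per_patch)
# 4096 1024 4 192
```

With the defaults, the derived patch width agrees with `patch_latent_dim = 192`, which is what motivates the spatial-only patching assumption.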
forward(
    input_ids: Tensor,
    text_indexes: Tensor,
    position_ids: Tensor,
    und_len: int,
    sequence_length: int,
    vision_tokens: list,
    vision_token_shapes: list,
    vision_sequence_indexes: Tensor,
    vision_mse_loss_indexes: Tensor,
    vision_timesteps: Tensor,
    vision_noisy_frame_indexes: list,
    sound_tokens: list[torch.Tensor] | None = None,
    sound_token_shapes: list[tuple[int, int, int]] | None = None,
    sound_sequence_indexes: torch.Tensor | None = None,
    sound_mse_loss_indexes: torch.Tensor | None = None,
    sound_timesteps: torch.Tensor | None = None,
    sound_noisy_frame_indexes: list[torch.Tensor] | None = None,
)
Parameters
- input_ids: Text token IDs placed at text_indexes in the joint sequence.
- text_indexes: Indices of text tokens in the joint sequence.
- position_ids: [3, sequence_length] mRoPE position IDs for the full joint sequence.
- und_len: Length of the causal text (understanding) prefix; generation tokens follow.
- sequence_length: Total length of the joint packed sequence.
- vision_tokens: Per-item vision latent tensors before patchify.
- vision_token_shapes: Patch grid shapes (T, H, W) per vision item.
- vision_sequence_indexes: Indices of vision tokens in the joint sequence.
- vision_mse_loss_indexes: Indices used to read vision predictions after the backbone.
- vision_timesteps: Per-patch diffusion timesteps for vision tokens.
- vision_noisy_frame_indexes: Noisy frame indices per vision item.
- sound_tokens: Optional sound latent tensors before packing.
- sound_token_shapes: Optional patch grid shapes for sound items.
- sound_sequence_indexes: Optional indices of sound tokens in the joint sequence.
- sound_mse_loss_indexes: Optional indices used to read sound predictions.
- sound_timesteps: Optional per-token diffusion timesteps for sound.
- sound_noisy_frame_indexes: Optional noisy frame indices per sound item.
Run a full denoising-step forward pass.
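To make the packed-sequence arguments more concrete, the sketch below derives `text_indexes`, `vision_sequence_indexes`, `und_len`, `sequence_length`, and a `[3, sequence_length]` position-ID layout for one text prefix followed by vision items. It uses plain Python lists where the real call expects tensors. Both helpers are hypothetical, and the position-ID scheme is an assumption for illustration only: text tokens advance along the temporal axis with spatial IDs of 0, and each vision item restarts its height and width IDs at 0. The actual assignment, including the effect of `unified_3d_mrope_temporal_modality_margin`, is not described on this page and may differ.

```python
from math import prod


def pack_layout(num_text_tokens, vision_token_shapes):
    """Lay out a text prefix followed by all vision patch tokens.

    vision_token_shapes holds one (T, H, W) patch grid per vision item.
    """
    und_len = num_text_tokens
    text_indexes = list(range(und_len))
    num_vision_tokens = sum(prod(shape) for shape in vision_token_shapes)
    sequence_length = und_len + num_vision_tokens
    vision_sequence_indexes = list(range(und_len, sequence_length))
    return text_indexes, vision_sequence_indexes, und_len, sequence_length


def illustrative_position_ids(und_len, vision_token_shapes):
    """Return [t_ids, h_ids, w_ids], each of length sequence_length.

    Illustrative scheme only (an assumption, not the library's rule).
    """
    # Text tokens: position lives on the temporal axis, spatial axes stay 0.
    t_ids = list(range(und_len))
    h_ids = [0] * und_len
    w_ids = [0] * und_len

    t_offset = und_len
    for T, H, W in vision_token_shapes:
        for t in range(T):
            for h in range(H):
                for w in range(W):
                    t_ids.append(t_offset + t)
                    h_ids.append(h)  # spatial IDs restart for every item
                    w_ids.append(w)
        t_offset += T
    return [t_ids, h_ids, w_ids]


shapes = [(2, 2, 3)]  # one vision item: 2 frames, 2 x 3 patches per frame
text_indexes, vision_sequence_indexes, und_len, sequence_length = pack_layout(4, shapes)
position_ids = illustrative_position_ids(und_len, shapes)

print(sequence_length)                          # 16
print(len(position_ids), len(position_ids[0]))  # 3 16
```

Here 4 text tokens plus 2 * 2 * 3 = 12 vision patch tokens give a joint sequence of length 16, and every axis of the position IDs has one entry per packed token.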