Cosmos3OmniTransformer
A Mixture-of-Transformer (MoT) joint vision-language transformer introduced as part of NVIDIA’s Cosmos3 world foundation model family. The model runs two parallel computation pathways over a packed joint sequence:
- a causal “understanding” pathway that self-attends over text tokens with causal masking, and
- a bi-directional “generation” pathway that cross-attends from generation tokens (vision + optional sound latents) over the full understanding-plus-generation key/value set.
The two pathways share the same hidden size and number of layers but maintain separate Q/K/V/O projections, MLPs, and RMSNorm parameters, which is what makes the architecture a Mixture-of-Transformer rather than a standard Mixture-of-Experts. Position information is supplied through a 3D multimodal RoPE (mRoPE) that interleaves temporal / height / width frequencies for video latents and reuses the temporal axis for text and audio.
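The attention pattern implied by the two pathways can be illustrated with a small boolean mask. This is a minimal sketch in plain Python, not the library's implementation. It assumes a single packed sample whose first `und_len` positions are understanding (text) tokens followed by generation tokens; the helper name `joint_attention_mask` is hypothetical.

```python
def joint_attention_mask(und_len: int, sequence_length: int) -> list[list[bool]]:
    """Return mask[q][k], True when query position q may attend to key position k.

    Assumption: positions [0, und_len) are understanding tokens and
    positions [und_len, sequence_length) are generation tokens.
    """
    mask = []
    for q in range(sequence_length):
        if q < und_len:
            # Understanding token: causal, sees itself and earlier text only.
            row = [k <= q for k in range(sequence_length)]
        else:
            # Generation token: bidirectional over the full joint sequence.
            row = [True] * sequence_length
        mask.append(row)
    return mask


mask = joint_attention_mask(und_len=3, sequence_length=5)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

In this toy case the first three rows form a lower-triangular (causal) block, and the last two rows are fully populated because generation tokens read keys and values from both the text prefix and each other.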
The model can be loaded as follows.
import torch
from diffusers import Cosmos3OmniTransformer

transformer = Cosmos3OmniTransformer.from_pretrained(
    "nvidia/Cosmos3-Nano", subfolder="transformer", torch_dtype=torch.bfloat16
)

Cosmos3OmniTransformer
class diffusers.Cosmos3OmniTransformer
(
    attention_bias: bool = False,
    attention_dropout: float = 0.0,
    dtype: str = 'bfloat16',
    head_dim: int = 128,
    hidden_size: int = 4096,
    intermediate_size: int = 12288,
    base_fps: int = 24,
    enable_fps_modulation: bool = True,
    latent_channel: int = 48,
    unified_3d_mrope_reset_spatial_ids: bool = True,
    unified_3d_mrope_temporal_modality_margin: int = 15000,
    latent_patch_size: int = 2,
    num_attention_heads: int = 32,
    num_hidden_layers: int = 36,
    num_key_value_heads: int = 8,
    patch_latent_dim: int = 192,
    rms_norm_eps: float = 1e-06,
    rope_scaling: dict | None = None,
    rope_theta: float = 5000000.0,
    action_dim: int | None = None,
    action_gen: bool = False,
    num_embodiment_domains: int = 32,
    sound_dim: int | None = None,
    sound_gen: bool = False,
    sound_latent_fps: float = 25.0,
    timestep_scale: float = 0.001,
    vocab_size: int = 151936
)
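Several of the default values in the signature are related by simple arithmetic. The check below is a sketch that only uses the numbers listed in the signature; the interpretations in the comments (grouped-query attention, and `patch_latent_dim` being a flattened spatial patch) are inferences from those numbers, not statements confirmed by this page.

```python
# Default values copied from the constructor signature.
config = {
    "hidden_size": 4096,
    "head_dim": 128,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "latent_channel": 48,
    "latent_patch_size": 2,
    "patch_latent_dim": 192,
}

# Query heads times per-head width equals the hidden size: 32 * 128 = 4096.
q_width = config["num_attention_heads"] * config["head_dim"]

# Fewer key/value heads than query heads suggests grouped-query attention,
# with this many query heads sharing each key/value head (assumption).
queries_per_kv_head = config["num_attention_heads"] // config["num_key_value_heads"]

# Assumption: one patch flattens a p x p spatial window of latent channels,
# so its width is latent_channel * p * p = 48 * 2 * 2 = 192.
patch_dim = config["latent_channel"] * config["latent_patch_size"] ** 2

print(q_width == config["hidden_size"])          # True
print(queries_per_kv_head)                       # 4
print(patch_dim == config["patch_latent_dim"])   # True
```

When overriding these defaults, keeping the same relations between the values is a reasonable sanity check before instantiating the model.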
forward
(
    input_ids: Tensor,
    text_indexes: Tensor,
    position_ids: Tensor,
    und_len: int,
    sequence_length: int,
    vision_tokens: list,
    vision_token_shapes: list,
    vision_sequence_indexes: Tensor,
    vision_mse_loss_indexes: Tensor,
    vision_timesteps: Tensor,
    vision_noisy_frame_indexes: list,
    sound_tokens: list[torch.Tensor] | None = None,
    sound_token_shapes: list[tuple[int, int, int]] | None = None,
    sound_sequence_indexes: torch.Tensor | None = None,
    sound_mse_loss_indexes: torch.Tensor | None = None,
    sound_timesteps: torch.Tensor | None = None,
    sound_noisy_frame_indexes: list[torch.Tensor] | None = None,
    action_tokens: list[torch.Tensor] | None = None,
    action_token_shapes: list[tuple[int, int, int]] | None = None,
    action_sequence_indexes: torch.Tensor | None = None,
    action_mse_loss_indexes: torch.Tensor | None = None,
    action_timesteps: torch.Tensor | None = None,
    action_noisy_frame_indexes: list[torch.Tensor] | None = None,
    action_domain_ids: list[torch.Tensor] | None = None
)
Parameters
- input_ids — Text token IDs placed at text_indexes in the joint sequence.
- text_indexes — Indices of text tokens in the joint sequence.
- position_ids — [3, sequence_length] mRoPE position IDs for the full joint sequence.
- und_len — Length of the causal text (understanding) prefix; generation tokens follow.
- sequence_length — Total length of the joint packed sequence.
- vision_tokens — Per-item vision latent tensors before patchify.
- vision_token_shapes — Patch grid shapes (T, H, W) per vision item.
- vision_sequence_indexes — Indices of vision tokens in the joint sequence.
- vision_mse_loss_indexes — Indices used to read vision predictions after the backbone.
- vision_timesteps — Per-patch diffusion timesteps for vision tokens.
- vision_noisy_frame_indexes — Noisy frame indices per vision item.
- sound_tokens — Optional sound latent tensors before packing.
- sound_token_shapes — Optional patch grid shapes for sound items.
- sound_sequence_indexes — Optional indices of sound tokens in the joint sequence.
- sound_mse_loss_indexes — Optional indices used to read sound predictions.
- sound_timesteps — Optional per-token diffusion timesteps for sound.
- sound_noisy_frame_indexes — Optional noisy frame indices per sound item.
- action_tokens — Optional action latent tensors before packing.
- action_token_shapes — Optional patch grid shapes (T, H, W) per action item.
- action_sequence_indexes — Optional indices of action tokens in the joint sequence.
- action_mse_loss_indexes — Optional indices used to read action predictions after the backbone.
- action_timesteps — Optional per-token diffusion timesteps for action tokens.
- action_noisy_frame_indexes — Optional noisy frame indices per action item.
- action_domain_ids — Optional per-item domain IDs selecting the action head weights.
Run a full denoising-step forward pass.
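To make the relationship between text_indexes, vision_sequence_indexes, und_len, sequence_length, and position_ids more concrete, here is a toy layout in plain Python. It is a sketch under loud assumptions, not the pipeline's actual packing code: the helper `toy_layout` is hypothetical, text tokens are assumed to use their running index on all three mRoPE axes, vision temporal IDs are assumed to continue right after the text (ignoring unified_3d_mrope_temporal_modality_margin), and vision height/width IDs are assumed to restart at zero.

```python
def toy_layout(num_text_tokens: int, grid: tuple[int, int, int]) -> dict:
    """Build index lists and a [3, sequence_length] position table for one
    text prefix followed by one vision item with patch grid (T, H, W)."""
    t, h, w = grid
    sequence_length = num_text_tokens + t * h * w

    # Text occupies the causal prefix; vision patches follow it.
    text_indexes = list(range(num_text_tokens))
    vision_sequence_indexes = list(range(num_text_tokens, sequence_length))

    temporal, height, width = [], [], []
    for i in range(num_text_tokens):
        # Assumption: text uses the same running index on every axis.
        temporal.append(i)
        height.append(i)
        width.append(i)
    for ti in range(t):
        for hi in range(h):
            for wi in range(w):
                # Assumption: temporal IDs continue after the text,
                # spatial IDs restart at zero for the vision item.
                temporal.append(num_text_tokens + ti)
                height.append(hi)
                width.append(wi)

    return {
        "und_len": num_text_tokens,
        "sequence_length": sequence_length,
        "text_indexes": text_indexes,
        "vision_sequence_indexes": vision_sequence_indexes,
        "position_ids": [temporal, height, width],
    }


layout = toy_layout(num_text_tokens=2, grid=(1, 2, 2))
print(layout["sequence_length"])          # 6
print(layout["vision_sequence_indexes"])  # [2, 3, 4, 5]
for axis in layout["position_ids"]:
    print(axis)
```

In a real call these lists would be converted to tensors and would typically be produced by the accompanying pipeline rather than by hand.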