Diffusers documentation

You are viewing the main version, which requires installation from source. For a regular pip install, check out the latest stable version (v0.39.0).

Cosmos3OmniTransformer

A Mixture-of-Transformer (MoT) joint vision-language transformer introduced as part of NVIDIA’s Cosmos3 world foundation model family. The model runs two parallel computation pathways over a packed joint sequence:

  • a causal “understanding” pathway that self-attends over text tokens with causal masking, and
  • a bi-directional “generation” pathway that cross-attends from generation tokens (vision + optional sound latents) over the full understanding-plus-generation key/value set.

The two pathways share the same hidden size and number of layers but maintain separate Q/K/V/O projections, MLPs, and RMSNorm parameters, which is what makes the architecture a Mixture-of-Transformer rather than a standard Mixture-of-Experts. Position information is supplied through a 3D multimodal RoPE (mRoPE) that interleaves temporal / height / width frequencies for video latents and reuses the temporal axis for text and audio.
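The attention pattern described above can be pictured as a boolean mask over the packed sequence. The following is a minimal sketch in plain Python, not the library's implementation: the helper name `build_joint_mask` is hypothetical, and it assumes that understanding tokens occupy the first `und_len` positions and never attend to generation tokens.

```python
def build_joint_mask(und_len: int, gen_len: int) -> list[list[bool]]:
    """Return mask[q][k] == True where query position q may attend to key position k.

    Hypothetical helper for illustration only. Assumes the understanding (text)
    prefix comes first and is followed by the generation tokens.
    """
    total = und_len + gen_len
    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        for k in range(total):
            if q < und_len:
                # Understanding query: causal attention restricted to the text prefix.
                mask[q][k] = k <= q
            else:
                # Generation query: bidirectional over understanding + generation keys.
                mask[q][k] = True
    return mask


mask = build_joint_mask(und_len=3, gen_len=2)
for row in mask:
    # "x" marks an allowed query/key pair, "." a masked one.
    print("".join("x" if allowed else "." for allowed in row))
```

The first three rows form a lower-triangular block (the causal understanding pathway), while the last two rows are fully populated (the generation pathway reading the full key/value set).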

The model can be loaded as follows.

import torch
from diffusers import Cosmos3OmniTransformer

transformer = Cosmos3OmniTransformer.from_pretrained(
    "nvidia/Cosmos3-Nano", subfolder="transformer", torch_dtype=torch.bfloat16
)

Cosmos3OmniTransformer

class diffusers.Cosmos3OmniTransformer

(
    attention_bias: bool = False,
    attention_dropout: float = 0.0,
    dtype: str = 'bfloat16',
    head_dim: int = 128,
    hidden_size: int = 4096,
    intermediate_size: int = 12288,
    base_fps: int = 24,
    enable_fps_modulation: bool = True,
    latent_channel: int = 48,
    unified_3d_mrope_reset_spatial_ids: bool = True,
    unified_3d_mrope_temporal_modality_margin: int = 15000,
    latent_patch_size: int = 2,
    num_attention_heads: int = 32,
    num_hidden_layers: int = 36,
    num_key_value_heads: int = 8,
    patch_latent_dim: int = 192,
    rms_norm_eps: float = 1e-06,
    rope_scaling: dict | None = None,
    rope_theta: float = 5000000.0,
    action_dim: int | None = None,
    action_gen: bool = False,
    num_embodiment_domains: int = 32,
    sound_dim: int | None = None,
    sound_gen: bool = False,
    sound_latent_fps: float = 25.0,
    timestep_scale: float = 0.001,
    vocab_size: int = 151936
)
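A quick arithmetic check shows how several of these defaults fit together. This is a sketch under stated assumptions: it reads `num_key_value_heads` as grouped-query attention (the usual meaning of that name in transformer configs) and guesses that `patch_latent_dim` is `latent_channel` times a square spatial patch. Neither relation is confirmed by this page; the defaults merely happen to be consistent with them.

```python
# Default values copied from the constructor signature.
hidden_size = 4096
head_dim = 128
num_attention_heads = 32
num_key_value_heads = 8
latent_channel = 48
latent_patch_size = 2
patch_latent_dim = 192

# Query projection width: one head_dim slice per attention head.
q_width = num_attention_heads * head_dim

# Key/value projection width, ASSUMING grouped-query attention.
kv_width = num_key_value_heads * head_dim

# Number of query heads sharing one key/value head under that assumption.
group_size = num_attention_heads // num_key_value_heads

# ASSUMED patchify relation: channels times a square spatial patch.
derived_patch_dim = latent_channel * latent_patch_size**2

print(q_width, kv_width, group_size, derived_patch_dim)
```

With the defaults, the query width equals `hidden_size` and the derived patch dimension equals `patch_latent_dim`, which is why these two readings seem plausible.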

forward

(
    input_ids: Tensor,
    text_indexes: Tensor,
    position_ids: Tensor,
    und_len: int,
    sequence_length: int,
    vision_tokens: list,
    vision_token_shapes: list,
    vision_sequence_indexes: Tensor,
    vision_mse_loss_indexes: Tensor,
    vision_timesteps: Tensor,
    vision_noisy_frame_indexes: list,
    sound_tokens: list[torch.Tensor] | None = None,
    sound_token_shapes: list[tuple[int, int, int]] | None = None,
    sound_sequence_indexes: torch.Tensor | None = None,
    sound_mse_loss_indexes: torch.Tensor | None = None,
    sound_timesteps: torch.Tensor | None = None,
    sound_noisy_frame_indexes: list[torch.Tensor] | None = None,
    action_tokens: list[torch.Tensor] | None = None,
    action_token_shapes: list[tuple[int, int, int]] | None = None,
    action_sequence_indexes: torch.Tensor | None = None,
    action_mse_loss_indexes: torch.Tensor | None = None,
    action_timesteps: torch.Tensor | None = None,
    action_noisy_frame_indexes: list[torch.Tensor] | None = None,
    action_domain_ids: list[torch.Tensor] | None = None
)

Parameters

  • input_ids — Text token IDs placed at text_indexes in the joint sequence.
  • text_indexes — Indices of text tokens in the joint sequence.
  • position_ids — [3, sequence_length] mRoPE position IDs for the full joint sequence.
  • und_len — Length of the causal text (understanding) prefix; generation tokens follow.
  • sequence_length — Total length of the joint packed sequence.
  • vision_tokens — Per-item vision latent tensors before patchify.
  • vision_token_shapes — Patch grid shapes (T, H, W) per vision item.
  • vision_sequence_indexes — Indices of vision tokens in the joint sequence.
  • vision_mse_loss_indexes — Indices used to read vision predictions after the backbone.
  • vision_timesteps — Per-patch diffusion timesteps for vision tokens.
  • vision_noisy_frame_indexes — Noisy frame indices per vision item.
  • sound_tokens — Optional sound latent tensors before packing.
  • sound_token_shapes — Optional patch grid shapes for sound items.
  • sound_sequence_indexes — Optional indices of sound tokens in the joint sequence.
  • sound_mse_loss_indexes — Optional indices used to read sound predictions.
  • sound_timesteps — Optional per-token diffusion timesteps for sound.
  • sound_noisy_frame_indexes — Optional noisy frame indices per sound item.
  • action_tokens — Optional action latent tensors before packing.
  • action_token_shapes — Optional patch grid shapes (T, H, W) per action item.
  • action_sequence_indexes — Optional indices of action tokens in the joint sequence.
  • action_mse_loss_indexes — Optional indices used to read action predictions after the backbone.
  • action_timesteps — Optional per-token diffusion timesteps for action tokens.
  • action_noisy_frame_indexes — Optional noisy frame indices per action item.
  • action_domain_ids — Optional per-item domain IDs selecting the action head weights.

Run a full denoising-step forward pass.
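To make the relationship between `und_len`, `sequence_length`, the index arguments, and the `[3, sequence_length]` shape of `position_ids` more concrete, here is a rough sketch for one text prefix followed by one vision item. The helper `sketch_joint_layout` is hypothetical and the ID scheme is an assumption: text tokens carry the same running index on all three axes (a convention used by some mRoPE implementations), vision tokens take a temporal ID that continues after the text and height/width IDs that start at 0. The real model may differ, for example through `unified_3d_mrope_temporal_modality_margin`, which is not modelled here.

```python
def sketch_joint_layout(und_len: int, grid: tuple[int, int, int]):
    """Illustrative packing of a text prefix plus one vision item.

    Hypothetical helper; the position-ID scheme is an assumption, not the
    library's actual behaviour.
    """
    t_size, h_size, w_size = grid
    num_vision = t_size * h_size * w_size
    sequence_length = und_len + num_vision

    # Text occupies the prefix; generation (vision) tokens follow.
    text_indexes = list(range(und_len))
    vision_sequence_indexes = list(range(und_len, sequence_length))

    temporal, height, width = [], [], []
    for i in range(und_len):
        # ASSUMPTION: text tokens share one running index on every axis.
        temporal.append(i)
        height.append(i)
        width.append(i)
    for t in range(t_size):
        for h in range(h_size):
            for w in range(w_size):
                # ASSUMPTION: temporal IDs continue after the text prefix,
                # spatial IDs are plain grid coordinates starting at 0.
                temporal.append(und_len + t)
                height.append(h)
                width.append(w)

    position_ids = [temporal, height, width]  # shape [3, sequence_length]
    return text_indexes, vision_sequence_indexes, position_ids, sequence_length


text_idx, vision_idx, pos_ids, seq_len = sketch_joint_layout(und_len=2, grid=(1, 2, 2))
print(seq_len, text_idx, vision_idx)
for axis_name, row in zip(("t", "h", "w"), pos_ids):
    print(axis_name, row)
```

In an actual call these lists would be converted to tensors, and the grid would come from the corresponding entry of `vision_token_shapes`.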