vllm.model_executor.models.moonvit

VL_VISION_ATTENTION_FUNCTIONS module-attribute

VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "sdpa": sdpa_attention,
}
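
Both entries share the same packed calling convention: q, k and v of shape (tot_seqlens, num_heads, head_dim) plus cumulative sequence lengths. A minimal dispatch sketch (the helper name run_vision_attention is hypothetical, not part of the module):

def run_vision_attention(q, k, v, cu_seqlens, impl: str = "sdpa"):
    # Look up the packed-attention backend by name and call it with the
    # same cumulative sequence lengths for queries and keys.
    attn_func = VL_VISION_ATTENTION_FUNCTIONS[impl]
    return attn_func(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens)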

Learnable2DInterpPosEmb

Bases: Module

Source code in vllm/model_executor/models/moonvit.py
class Learnable2DInterpPosEmb(nn.Module):

    def __init__(self,
                 height: int,
                 width: int,
                 dim: int,
                 interpolation_mode: str = "bicubic") -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
        pos_embs = []
        for shape in grid_hws.tolist():
            if shape == self.weight.shape[:-1]:
                pos_embs.append(self.weight.flatten(end_dim=1))
            else:
                pos_embs.append(
                    F.interpolate(
                        self.weight.permute((2, 0, 1)).unsqueeze(0),
                        size=shape,
                        mode=self.interpolation_mode,
                    ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
        out = x + torch.cat(pos_embs)
        return out

height instance-attribute

height = height

interpolation_mode instance-attribute

interpolation_mode = interpolation_mode

weight instance-attribute

weight = Parameter(empty(height, width, dim))

width instance-attribute

width = width

__init__

__init__(
    height: int,
    width: int,
    dim: int,
    interpolation_mode: str = "bicubic",
) -> None
Source code in vllm/model_executor/models/moonvit.py
def __init__(self,
             height: int,
             width: int,
             dim: int,
             interpolation_mode: str = "bicubic") -> None:
    super().__init__()
    self.height = height
    self.width = width
    self.interpolation_mode = interpolation_mode
    self.weight = nn.Parameter(torch.empty(height, width, dim))
    self.reset_parameters()

forward

forward(x: Tensor, grid_hws: Tensor) -> Tensor
Source code in vllm/model_executor/models/moonvit.py
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
    pos_embs = []
    for shape in grid_hws.tolist():
        if shape == self.weight.shape[:-1]:
            pos_embs.append(self.weight.flatten(end_dim=1))
        else:
            pos_embs.append(
                F.interpolate(
                    self.weight.permute((2, 0, 1)).unsqueeze(0),
                    size=shape,
                    mode=self.interpolation_mode,
                ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
    out = x + torch.cat(pos_embs)
    return out

reset_parameters

reset_parameters()
Source code in vllm/model_executor/models/moonvit.py
def reset_parameters(self):
    nn.init.normal_(self.weight)
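
A minimal usage sketch for Learnable2DInterpPosEmb (the sizes are illustrative): the learnable 14x14 table is resampled to each image's patch grid before being added to the packed patch features.

import torch

pos_emb = Learnable2DInterpPosEmb(height=14, width=14, dim=64)
grid_hws = torch.tensor([[14, 14], [20, 12]])   # per-image (height, width) patch grids
x = torch.randn(14 * 14 + 20 * 12, 64)          # packed patch features, one row per patch
out = pos_emb(x, grid_hws)                      # -> same shape as x
assert out.shape == x.shape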

MLP2

Bases: Module

Parameters:

Name Type Description Default
dims list[int]

[in_dim, hidden_dim, out_dim]

required
bias

Whether to use bias in the linear layers.

True
Source code in vllm/model_executor/models/moonvit.py
class MLP2(nn.Module):
    """
    Args:
        dims: [in_dim, hidden_dim, out_dim]
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
        self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
        self.activation = activation
        for m in [self.fc0, self.fc1]:
            nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc0(x)
        x = self.activation(x)
        return self.fc1(x)

activation instance-attribute

activation = activation

fc0 instance-attribute

fc0 = Linear(dims[0], dims[1], bias=bias)

fc1 instance-attribute

fc1 = Linear(dims[1], dims[2], bias=bias)

__init__

__init__(dims: list[int], activation, bias=True)
Source code in vllm/model_executor/models/moonvit.py
def __init__(self, dims: list[int], activation, bias=True):
    super().__init__()
    assert len(dims) == 3
    self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
    self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
    self.activation = activation
    for m in [self.fc0, self.fc1]:
        nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
        if m.bias is not None:
            nn.init.zeros_(m.bias)

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/moonvit.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.fc0(x)
    x = self.activation(x)
    return self.fc1(x)
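
A quick sketch (illustrative sizes): a two-layer GELU MLP mapping 64 -> 256 -> 64, following the dims = [in_dim, hidden_dim, out_dim] convention.

import torch
import torch.nn.functional as F

mlp = MLP2([64, 256, 64], activation=F.gelu)
y = mlp(torch.randn(10, 64))   # -> (10, 64)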

MoonVisionPatchEmbed

Bases: Module

Source code in vllm/model_executor/models/moonvit.py
class MoonVisionPatchEmbed(nn.Module):

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: Union[int, tuple[int, int]] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
    ):
        super().__init__()
        assert isinstance(
            patch_size,
            (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}"
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert (len(patch_size) == 2
                ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_dim,
                              out_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height,
                                               width=pos_emb_width,
                                               dim=out_dim)

    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
            grid_hw (N, 2): grid height and width

        Returns:
            (L, Cout) tensor
        """
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_hw)
        return x

patch_size instance-attribute

patch_size = patch_size

pos_emb instance-attribute

pos_emb = Learnable2DInterpPosEmb(
    height=pos_emb_height, width=pos_emb_width, dim=out_dim
)

proj instance-attribute

proj = Conv2d(
    in_dim,
    out_dim,
    kernel_size=patch_size,
    stride=patch_size,
)

__init__

__init__(
    out_dim: int,
    in_dim: int = 3,
    patch_size: Union[int, tuple[int, int]] = (14, 14),
    pos_emb_height: int = 14,
    pos_emb_width: int = 14,
)
Source code in vllm/model_executor/models/moonvit.py
def __init__(
    self,
    out_dim: int,
    in_dim: int = 3,
    patch_size: Union[int, tuple[int, int]] = (14, 14),
    pos_emb_height: int = 14,
    pos_emb_width: int = 14,
):
    super().__init__()
    assert isinstance(
        patch_size,
        (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}"
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    assert (len(patch_size) == 2
            ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
    self.patch_size = patch_size

    self.proj = nn.Conv2d(in_dim,
                          out_dim,
                          kernel_size=patch_size,
                          stride=patch_size)

    self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height,
                                           width=pos_emb_width,
                                           dim=out_dim)

forward

forward(x: Tensor, grid_hw: Tensor) -> Tensor

Parameters:

Name Type Description Default
x (L, Channels)

input tensor

required
grid_hw (N, 2)

grid height and width

required

Returns:

Type Description
Tensor

(L, Cout) tensor

Source code in vllm/model_executor/models/moonvit.py
def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x (L, Channels): input tensor
        grid_hw (N, 2): grid height and width

    Returns:
        (L, Cout) tensor
    """
    x = self.proj(x).view(x.size(0), -1)
    # apply positional embedding
    x = self.pos_emb(x, grid_hw)
    return x
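
A shape-level sketch (illustrative sizes, assuming the default 14x14 patch size): the packed input holds one pixel patch per row, and grid_hw[:, 0] * grid_hw[:, 1] must sum to the number of rows.

import torch

embed = MoonVisionPatchEmbed(out_dim=64)
grid_hw = torch.tensor([[8, 6]])        # one image with an 8x6 patch grid
patches = torch.randn(48, 3, 14, 14)    # (L, in_dim, patch_h, patch_w), L = 8 * 6
tokens = embed(patches, grid_hw)        # -> (48, 64)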

MoonVitEncoder

Bases: Module

Source code in vllm/model_executor/models/moonvit.py
class MoonVitEncoder(nn.Module):

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
    ) -> None:
        super().__init__()

        self.rope_2d = Rope2DPosEmb(
            block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
        self.blocks = nn.ModuleList(
            [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(self, hidden_states: torch.Tensor,
                grid_hw: torch.Tensor) -> torch.Tensor:
        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
            grid_hws=grid_hw)

        lengths = torch.cat((
            torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
            grid_hw[:, 0] * grid_hw[:, 1],
        ))
        cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

        for _, block in enumerate(self.blocks):
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  rope_freqs_cis=rope_freqs_cis)

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

blocks instance-attribute

blocks = ModuleList(
    [
        MoonVitEncoderLayer(**block_cfg)
        for _ in range(num_layers)
    ]
)

final_layernorm instance-attribute

final_layernorm = LayerNorm(hidden_dim)

rope_2d instance-attribute

rope_2d = Rope2DPosEmb(
    block_cfg["hidden_dim"] // block_cfg["num_heads"],
    512,
    512,
)

__init__

__init__(
    hidden_dim: int, num_layers: int, block_cfg: dict
) -> None
Source code in vllm/model_executor/models/moonvit.py
def __init__(
    self,
    hidden_dim: int,
    num_layers: int,
    block_cfg: dict,
) -> None:
    super().__init__()

    self.rope_2d = Rope2DPosEmb(
        block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
    self.blocks = nn.ModuleList(
        [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
    self.final_layernorm = nn.LayerNorm(hidden_dim)

forward

forward(hidden_states: Tensor, grid_hw: Tensor) -> Tensor
Source code in vllm/model_executor/models/moonvit.py
def forward(self, hidden_states: torch.Tensor,
            grid_hw: torch.Tensor) -> torch.Tensor:
    rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
        grid_hws=grid_hw)

    lengths = torch.cat((
        torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
        grid_hw[:, 0] * grid_hw[:, 1],
    ))
    cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

    for _, block in enumerate(self.blocks):
        hidden_states = block(hidden_states,
                              cu_seqlens,
                              rope_freqs_cis=rope_freqs_cis)

    hidden_states = self.final_layernorm(hidden_states)

    return hidden_states
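
A standalone sketch of how the encoder derives cu_seqlens from grid_hw (illustrative grids): each image contributes height * width tokens to the packed sequence.

import torch

grid_hw = torch.tensor([[8, 6], [4, 4]])   # two packed images
lengths = torch.cat((torch.zeros(1, dtype=grid_hw.dtype),
                     grid_hw[:, 0] * grid_hw[:, 1]))
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
# cu_seqlens: tensor([ 0, 48, 64], dtype=torch.int32)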

MoonVitEncoderLayer

Bases: Module

Source code in vllm/model_executor/models/moonvit.py
class MoonVitEncoderLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = "sdpa",
        activation=F.gelu,
        attn_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        # use fa2 in vllm by default
        if is_flash_attn_2_available():
            self.attn_implementation = "flash_attention_2"

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
            cu_seqlens (torch.Tensor):
        """
        xqkv = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (batch_size, seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(xq,
                             xk,
                             xv,
                             q_cu_seqlens=cu_seqlens,
                             k_cu_seqlens=cu_seqlens)

        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set

        Returns:
            output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input
        """
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)
        attn_out = self.attention_qkvpacked(hidden_states,
                                            cu_seqlens,
                                            rope_freqs_cis=rope_freqs_cis)
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.mlp(self.norm1(hidden_states))
        hidden_states = residual + hidden_states
        return hidden_states

attn_implementation instance-attribute

attn_implementation = attn_implementation

hidden_dim instance-attribute

hidden_dim = hidden_dim

hidden_size_per_attention_head instance-attribute

hidden_size_per_attention_head = hidden_dim // num_heads

mlp instance-attribute

mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)

norm0 instance-attribute

norm0 = LayerNorm(hidden_dim)

norm1 instance-attribute

norm1 = LayerNorm(hidden_dim)

num_heads instance-attribute

num_heads = num_heads

wo instance-attribute

wo = Linear(hidden_dim, hidden_dim, bias=attn_bias)

wqkv instance-attribute

wqkv = Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)

__init__

__init__(
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    *,
    attn_implementation: str = "sdpa",
    activation=gelu,
    attn_bias: bool = False,
)
Source code in vllm/model_executor/models/moonvit.py
def __init__(
    self,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    *,
    attn_implementation: str = "sdpa",
    activation=F.gelu,
    attn_bias: bool = False,
):
    super().__init__()
    self.num_heads = num_heads
    self.hidden_dim = hidden_dim
    self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
    self.attn_implementation = attn_implementation
    # use fa2 in vllm by default
    if is_flash_attn_2_available():
        self.attn_implementation = "flash_attention_2"

    self.norm0 = nn.LayerNorm(hidden_dim)
    self.norm1 = nn.LayerNorm(hidden_dim)
    self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
    self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
    self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

attention_qkvpacked

attention_qkvpacked(
    x: Tensor,
    cu_seqlens: Tensor,
    rope_freqs_cis: Optional[Tensor] = None,
)

Parameters:

Name Type Description Default
x Tensor

(batch_size, seqlen, hidden_dim)

required
cu_seqlens Tensor
required
Source code in vllm/model_executor/models/moonvit.py
def attention_qkvpacked(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rope_freqs_cis: Optional[torch.Tensor] = None,
):
    """
    Args:
        x (torch.Tensor): (batch_size, seqlen, hidden_dim)
        cu_seqlens (torch.Tensor):
    """
    xqkv = self.wqkv(x)

    qkv_shape = xqkv.size()[:-1] + (
        3,
        self.num_heads,
        self.hidden_size_per_attention_head,
    )
    # xqkv: (batch_size, seqlen, 3, nheads, headdim)
    xqkv = xqkv.view(*qkv_shape)
    xq, xk, xv = torch.unbind(xqkv, dim=-3)

    xq, xk = apply_rope(xq, xk, rope_freqs_cis)

    attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
    attn_out = attn_func(xq,
                         xk,
                         xv,
                         q_cu_seqlens=cu_seqlens,
                         k_cu_seqlens=cu_seqlens)

    attn_out = self.wo(attn_out)
    return attn_out

forward

forward(
    hidden_states: Tensor,
    cu_seqlens: Tensor,
    rope_freqs_cis: Union[Tensor, None] = None,
) -> Tensor

Parameters:

Name Type Description Default
hidden_states Tensor

Non-packed (B, N, D) or packed (L, D). If non-packed, seqlens should be None; if packed, seqlens should be set.

required

Returns:

Name Type Description
output Tensor

Same shape as the input: (B, N, D) for non-packed input, (L, D) for packed input.

Source code in vllm/model_executor/models/moonvit.py
def forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rope_freqs_cis: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Args:
        hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set

    Returns:
        output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input
    """
    residual = hidden_states
    hidden_states = self.norm0(hidden_states)
    attn_out = self.attention_qkvpacked(hidden_states,
                                        cu_seqlens,
                                        rope_freqs_cis=rope_freqs_cis)
    hidden_states = residual + attn_out

    residual = hidden_states
    hidden_states = self.mlp(self.norm1(hidden_states))
    hidden_states = residual + hidden_states
    return hidden_states
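
A minimal packed-input sketch (illustrative sizes; assumes the SDPA path, i.e. flash-attn is not installed, and float32 inputs): the rotary frequencies come from a Rope2DPosEmb whose dim equals the per-head dimension.

import torch

layer = MoonVitEncoderLayer(num_heads=4, hidden_dim=64, mlp_dim=256)
rope = Rope2DPosEmb(dim=64 // 4, max_height=64, max_width=64, device="cpu")

grid_hw = torch.tensor([[8, 6]])
freqs_cis = rope.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
cu_seqlens = torch.tensor([0, 48], dtype=torch.int32)

hidden = torch.randn(48, 64)                                # packed (L, D) input
out = layer(hidden, cu_seqlens, rope_freqs_cis=freqs_cis)   # -> (48, 64)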

MoonVitPretrainedModel

Bases: PreTrainedModel

Source code in vllm/model_executor/models/moonvit.py
class MoonVitPretrainedModel(PreTrainedModel):
    config_class = MoonViTConfig
    model_type = "moonvit"
    _no_split_modules = ["PackingTransformer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.patch_embed = MoonVisionPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
        )

        self.encoder = MoonVitEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": PytorchGELUTanh(),
                "attn_bias": True,
                "attn_implementation": config._attn_implementation,
            },
        )

    def forward(self, pixel_values: torch.Tensor,
                grid_hw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_hw (torch.Tensor): The grid height and width.

        Returns:
            torch.Tensor: The output tokens.
        """
        hidden_states = self.patch_embed(pixel_values, grid_hw)
        hidden_states = self.encoder(hidden_states, grid_hw)
        hidden_states = patch_merger(hidden_states,
                                     grid_hw,
                                     merge_kernel_size=self.merge_kernel_size)
        return hidden_states

_no_split_modules class-attribute instance-attribute

_no_split_modules = ['PackingTransformer']

_supports_flash_attn_2 class-attribute instance-attribute

_supports_flash_attn_2 = True

_supports_sdpa class-attribute instance-attribute

_supports_sdpa = True

config_class class-attribute instance-attribute

config_class = MoonViTConfig

encoder instance-attribute

encoder = MoonVitEncoder(
    hidden_dim=hidden_size,
    num_layers=num_hidden_layers,
    block_cfg={
        "num_heads": num_attention_heads,
        "hidden_dim": hidden_size,
        "mlp_dim": intermediate_size,
        "activation": PytorchGELUTanh(),
        "attn_bias": True,
        "attn_implementation": _attn_implementation,
    },
)

merge_kernel_size instance-attribute

merge_kernel_size = merge_kernel_size

model_type class-attribute instance-attribute

model_type = 'moonvit'

patch_embed instance-attribute

patch_embed = MoonVisionPatchEmbed(
    out_dim=hidden_size,
    patch_size=patch_size,
    pos_emb_height=init_pos_emb_height,
    pos_emb_width=init_pos_emb_width,
)

patch_size instance-attribute

patch_size = patch_size

__init__

__init__(config: MoonViTConfig, *inputs, **kwargs)
Source code in vllm/model_executor/models/moonvit.py
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)
    config = deepcopy(config)
    self.merge_kernel_size = config.merge_kernel_size
    self.patch_size = config.patch_size
    self.patch_embed = MoonVisionPatchEmbed(
        out_dim=config.hidden_size,
        patch_size=config.patch_size,
        pos_emb_height=config.init_pos_emb_height,
        pos_emb_width=config.init_pos_emb_width,
    )

    self.encoder = MoonVitEncoder(
        hidden_dim=config.hidden_size,
        num_layers=config.num_hidden_layers,
        block_cfg={
            "num_heads": config.num_attention_heads,
            "hidden_dim": config.hidden_size,
            "mlp_dim": config.intermediate_size,
            "activation": PytorchGELUTanh(),
            "attn_bias": True,
            "attn_implementation": config._attn_implementation,
        },
    )

forward

forward(pixel_values: Tensor, grid_hw: Tensor) -> Tensor

Parameters:

Name Type Description Default
pixel_values Tensor

The input pixel values.

required
grid_hw Tensor

The grid height and width.

required

Returns:

Type Description
Tensor

torch.Tensor: The output tokens.

Source code in vllm/model_executor/models/moonvit.py
def forward(self, pixel_values: torch.Tensor,
            grid_hw: torch.Tensor) -> torch.Tensor:
    """
    Args:
        pixel_values (torch.Tensor): The input pixel values.
        grid_hw (torch.Tensor): The grid height and width.

    Returns:
        torch.Tensor: The output tokens.
    """
    hidden_states = self.patch_embed(pixel_values, grid_hw)
    hidden_states = self.encoder(hidden_states, grid_hw)
    hidden_states = patch_merger(hidden_states,
                                 grid_hw,
                                 merge_kernel_size=self.merge_kernel_size)
    return hidden_states
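
A shape-level sketch of the full pipeline (hedged: model is an already-constructed MoonVitPretrainedModel on the SDPA path with a 14x14 patch size; the sizes are illustrative). Note that the output follows patch_merger, i.e. a list with one tensor of merged patch groups per image.

import torch

grid_hw = torch.tensor([[8, 6]])             # one image, 8x6 patches
pixel_values = torch.randn(48, 3, 14, 14)    # packed 14x14 pixel patches
outputs = model(pixel_values, grid_hw)       # list of per-image tensors
# with merge_kernel_size = (2, 2): outputs[0].shape == (12, 4, hidden_size)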

MoonVitVLProjector

Bases: Module

Source code in vllm/model_executor/models/moonvit.py
class MoonVitVLProjector(nn.Module):

    def __init__(
        self,
        in_channels: int,
        merge_kernel_size: list[int, int],
        hidden_act: str = "gelu",
        ln_eps: float = 1e-5,
        out_dim: int = 4096,
    ):
        super().__init__()
        self.hidden_size = in_channels * merge_kernel_size[
            0] * merge_kernel_size[1]

        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
        self.linear_1 = nn.Linear(self.hidden_size,
                                  self.hidden_size,
                                  bias=True)
        self.act = ACT2FN[hidden_act]
        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

act instance-attribute

act = ACT2FN[hidden_act]

hidden_size instance-attribute

hidden_size = (
    in_channels
    * merge_kernel_size[0]
    * merge_kernel_size[1]
)

linear_1 instance-attribute

linear_1 = Linear(hidden_size, hidden_size, bias=True)

linear_2 instance-attribute

linear_2 = Linear(hidden_size, out_dim, bias=True)

pre_norm instance-attribute

pre_norm = LayerNorm(in_channels, eps=ln_eps)

__init__

__init__(
    in_channels: int,
    merge_kernel_size: list[int, int],
    hidden_act: str = "gelu",
    ln_eps: float = 1e-05,
    out_dim: int = 4096,
)
Source code in vllm/model_executor/models/moonvit.py
def __init__(
    self,
    in_channels: int,
    merge_kernel_size: list[int, int],
    hidden_act: str = "gelu",
    ln_eps: float = 1e-5,
    out_dim: int = 4096,
):
    super().__init__()
    self.hidden_size = in_channels * merge_kernel_size[
        0] * merge_kernel_size[1]

    self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
    self.linear_1 = nn.Linear(self.hidden_size,
                              self.hidden_size,
                              bias=True)
    self.act = ACT2FN[hidden_act]
    self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/moonvit.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
    hidden_states = self.linear_1(hidden_states)
    hidden_states = self.act(hidden_states)
    hidden_states = self.linear_2(hidden_states)
    return hidden_states

Rope2DPosEmb

Bases: Module

2D rotary position embedding with multi-resolution support.

This class is intended to be used in the following way:

1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
2. Before each forward pass, call get_freqs_cis_by_* to get the freqs_cis tensor for this iteration.
3. During the forward pass, pass the freqs_cis tensor to each attention layer, and call apply just before each attention operation. The rope is shared across all attention layers and all heads.

Refs:

- RoFormer: https://arxiv.org/abs/2104.09864
- VisionLLaMA: https://arxiv.org/abs/2403.00522
- https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

Parameters:

Name Type Description Default
dim int

usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)

required
max_height int

the maximum height of the 2D grid

required
max_width int

the maximum width of the 2D grid

required
theta_base float

the base of the theta

10000
device str

the device to store the precomputed cis

'cuda'
Source code in vllm/model_executor/models/moonvit.py
class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
        The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000,
                 device="cuda"):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.device = device

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    @cached_property
    def precomputed_freqs_cis(self) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
            weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim))   with (i in [0, dim//4))
            note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(self.device)
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = (torch.arange(0, self.dim,
                                  4)[:(self.dim // 4)].float().to(self.device)
                     )  # C/4
        freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1),
             y_cis.unsqueeze(dim=-1)], dim=-1)
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        shapes = grid_hws.tolist()
        assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
                   for h, w in shapes), (
                       shapes,
                       self.max_height,
                       self.max_width,
                   )
        freqs_cis = torch.cat(
            [
                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
                for h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis

    def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor,
                             pos_idx_mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
        Return:
            freqs_cis: tensor of shape (..., dim//2)
        """
        assert (pos_idx.shape[:-1] == pos_idx_mask.shape
                and pos_idx.shape[-1] == 2 and pos_idx.ndim
                == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape)
        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype

        shp = pos_idx_mask.shape + (self.dim // 2, )  # ..., head_dim/2
        freqs_cis = torch.ones(shp, dtype=torch.complex64,
                               device=self.device)  # ..., head_dim/2
        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[
            ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]]
        return freqs_cis

device instance-attribute

device = device

dim instance-attribute

dim = dim

max_height instance-attribute

max_height = max_height

max_width instance-attribute

max_width = max_width

precomputed_freqs_cis cached property

precomputed_freqs_cis: Tensor

Calculate the cis(freqs) for each position in the 2D grid.

Returns a complex tensor of shape (max_height, max_width, dim//2) with values:

height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
width axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)), with i in [0, dim//4)

Note: cis is the mathematical notation defined by cis x = cos x + i sin x.

theta_base instance-attribute

theta_base = theta_base

__init__

__init__(
    dim: int,
    max_height: int,
    max_width: int,
    theta_base=10000,
    device="cuda",
)
Source code in vllm/model_executor/models/moonvit.py
def __init__(self,
             dim: int,
             max_height: int,
             max_width: int,
             theta_base=10000,
             device="cuda"):
    super().__init__()
    self.dim = dim
    assert self.dim % 4 == 0, "dim must be divisible by 4"
    self.max_height = max_height
    self.max_width = max_width
    self.theta_base = theta_base
    self.device = device

extra_repr

extra_repr()
Source code in vllm/model_executor/models/moonvit.py
def extra_repr(self):
    return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

get_freqs_cis_by_idx

get_freqs_cis_by_idx(
    pos_idx: Tensor, pos_idx_mask: Tensor
) -> Tensor

Parameters:

Name Type Description Default
pos_idx Tensor

tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.

required
pos_idx_mask Tensor

a mask of shape (...); the leading dimensions should be the same as pos_idx. RoPE will only be applied to the tokens with a True mask; freqs_cis for the tokens with a False mask will be ones.

required

Return: freqs_cis: tensor of shape (..., dim//2)

Source code in vllm/model_executor/models/moonvit.py
def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor,
                         pos_idx_mask: torch.Tensor) -> torch.Tensor:
    """
    Args:
        pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
        pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
            Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
    Return:
        freqs_cis: tensor of shape (..., dim//2)
    """
    assert (pos_idx.shape[:-1] == pos_idx_mask.shape
            and pos_idx.shape[-1] == 2 and pos_idx.ndim
            == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape)
    assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype

    shp = pos_idx_mask.shape + (self.dim // 2, )  # ..., head_dim/2
    freqs_cis = torch.ones(shp, dtype=torch.complex64,
                           device=self.device)  # ..., head_dim/2
    freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[
        ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]]
    return freqs_cis

get_freqs_cis_by_seqlens

get_freqs_cis_by_seqlens(grid_hws: Tensor) -> Tensor

Parameters:

Name Type Description Default
grid_hws Tensor

containing list of (height, width) or (t, height, width) tuples.

required

Returns: freqs_cis: tensor of shape (sum(t * height * width), dim//2)

Source code in vllm/model_executor/models/moonvit.py
def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
    """
    Args:
        grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
    Returns:
        freqs_cis: tensor of shape (sum(t * height * width), dim//2)
    """
    shapes = grid_hws.tolist()
    assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
               for h, w in shapes), (
                   shapes,
                   self.max_height,
                   self.max_width,
               )
    freqs_cis = torch.cat(
        [
            self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
            for h, w in shapes
        ],
        dim=0,
    )
    return freqs_cis
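
A minimal usage sketch following steps 1-3 of the class docstring (illustrative sizes; device="cpu" keeps the sketch self-contained):

import torch

rope = Rope2DPosEmb(dim=64, max_height=64, max_width=64, device="cpu")
grid_hws = torch.tensor([[8, 6]])                     # one 8x6 patch grid
freqs_cis = rope.get_freqs_cis_by_seqlens(grid_hws)   # -> (48, 32), complex64

xq = torch.randn(48, 4, 64)                           # (seq, num_heads, head_dim)
xk = torch.randn(48, 4, 64)
xq_rot, xk_rot = apply_rope(xq, xk, freqs_cis)        # same shapes as the inputs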

_apply_rope_input_validation

_apply_rope_input_validation(x, freqs_cis)
Source code in vllm/model_executor/models/moonvit.py
def _apply_rope_input_validation(x, freqs_cis):
    assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
    assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype

apply_rope

apply_rope(
    xq: Tensor, xk: Tensor, freqs_cis: Tensor
) -> tuple[Tensor, Tensor]

Parameters (the leading dimensions of all inputs should be the same):

Name Type Description Default
xq Tensor

query, tensor of shape (..., num_heads, head_dim)

required
xk Tensor

key, tensor of shape (..., num_heads, head_dim)

required
freqs_cis Tensor

tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.

required

Returns: xq_out, xk_out: tensors of shape (..., num_heads, head_dim)

Source code in vllm/model_executor/models/moonvit.py
def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
               freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)

multihead_attention

multihead_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_cu_seqlens: Optional[Tensor] = None,
    k_cu_seqlens: Optional[Tensor] = None,
)

Multi-head attention using flash attention 2.

Parameters:

Name Type Description Default
q, k, v

tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing.

required
q_cu_seqlens Tensor

cumulative sequence lengths of q. The first element should be 0 and the last element should be q.shape[0].

None
k_cu_seqlens Tensor

cumulative sequence lengths of k. The first element should be 0 and the last element should be k.shape[0].

None

Returns:

Name Type Description
output

shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing, where dim = num_heads * head_dim

Source code in vllm/model_executor/models/moonvit.py
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
):
    """Multi-head attention using flash attention 2.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
            The first element should be 0 and the last element should be q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
            The first element should be 0 and the last element should be k.shape[0].

    Returns:
        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
            where dim = num_heads * head_dim
    """
    # Unified format legal check
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    assert q_cu_seqlens[-1] == q.shape[
        0], "q_cu_seqlens must sum to q.shape[0]"
    assert (k_cu_seqlens[-1] == k.shape[0] ==
            v.shape[0]), "k_cu_seqlens must sum to k.shape[0]"
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"

    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
    )
    attn_out = attn_out.flatten(start_dim=-2)

    return attn_out
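
A call sketch (hedged: requires a CUDA device with flash-attn installed; the packed inputs must be fp16 or bf16, and the sizes are illustrative):

import torch

q = torch.randn(64, 4, 32, dtype=torch.bfloat16, device="cuda")
k = torch.randn(64, 4, 32, dtype=torch.bfloat16, device="cuda")
v = torch.randn(64, 4, 32, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.tensor([0, 48, 64], dtype=torch.int32, device="cuda")

out = multihead_attention(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens)
# out: (64, 4 * 32) -- the head dimension is flattened into the feature dimension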

patch_merger

patch_merger(
    x: Tensor,
    grid_hw: Tensor,
    merge_kernel_size: list[int, int] = (2, 2),
) -> list[Tensor]
Source code in vllm/model_executor/models/moonvit.py
def patch_merger(
        x: torch.Tensor,
        grid_hw: torch.Tensor,
        merge_kernel_size: list[int, int] = (2, 2),
) -> list[torch.Tensor]:
    d_model = x.size(-1)

    outputs = []
    pre_sum = 0
    for x_shape in grid_hw.tolist():
        height, width = x_shape[0], x_shape[1]
        # Get the current sequence
        seq = x[pre_sum:pre_sum + height * width]
        # Reshape along self.merge_kernel_size and concat to the last dimension
        kernel_height, kernel_width = merge_kernel_size
        new_height, new_width = height // kernel_height, width // kernel_width
        reshaped_seq = seq.view(new_height, kernel_height, new_width,
                                kernel_width, d_model)
        reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
        padded_seq = reshaped_seq.view(new_height * new_width,
                                       kernel_height * kernel_width, -1)
        outputs.append(padded_seq)
        pre_sum += height * width

    return outputs
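
A shape sketch (illustrative sizes): merging an 8x6 token grid with the default 2x2 kernel groups every 2x2 block of patch tokens into one row of four tokens.

import torch

x = torch.randn(48, 64)                  # packed tokens for one 8x6 grid
grid_hw = torch.tensor([[8, 6]])
merged = patch_merger(x, grid_hw, merge_kernel_size=(2, 2))
# merged[0].shape == (12, 4, 64): 4x3 groups of 2x2 tokens each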

sdpa_attention

sdpa_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_cu_seqlens: Optional[Tensor] = None,
    k_cu_seqlens: Optional[Tensor] = None,
) -> Tensor

SDPA attention.

Parameters:

Name Type Description Default
q, k, v

tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing.

required
Source code in vllm/model_executor/models/moonvit.py
def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SDPA attention.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
    """
    seq_length = q.shape[0]
    attention_mask = torch.zeros([1, seq_length, seq_length],
                                 device=q.device,
                                 dtype=torch.bool)
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
        ] = True
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    attn_output = F.scaled_dot_product_attention(q,
                                                 k,
                                                 v,
                                                 attention_mask,
                                                 dropout_p=0.0)
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output
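
A call sketch (illustrative sizes): two packed sequences of 48 and 16 tokens are attended independently through the block-diagonal mask built above; unlike the flash-attention path, this one also runs on CPU in float32.

import torch

q = torch.randn(64, 4, 32)
k = torch.randn(64, 4, 32)
v = torch.randn(64, 4, 32)
cu_seqlens = torch.tensor([0, 48, 64], dtype=torch.int32)

out = sdpa_attention(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens)
# out.shape == (64, 128)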