Skip to content

vllm.model_executor.layers.rotary_embedding

Rotary Positional Embeddings.

_ROPE_DICT module-attribute

_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}

DeepseekScalingRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with YaRN method.

Credits to Peng et al. github.com/jquesnelle/yarn

Source code in vllm/model_executor/layers/rotary_embedding.py
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding that applies YaRN-style scaling.

    The inverse frequencies are a per-dimension blend of interpolated
    (divided by ``scaling_factor``) and extrapolated (unscaled) values, and
    the cached cos/sin values are multiplied by ``self.mscale``.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        """Store the scaling settings, then run the base initializer.

        Args:
            head_size, rotary_dim, max_position_embeddings, base,
                is_neox_style, dtype: Forwarded unchanged to the base
                class initializer.
            scaling_factor: Divides the interpolated inverse frequencies and
                multiplies the number of cached positions.
            extrapolation_factor: Multiplier applied to the blend mask in
                ``_compute_inv_freq``.
            attn_factor: Multiplies the magnitude scale ``self.mscale``.
            beta_fast, beta_slow: Passed to ``_yarn_find_correction_range``.
            mscale: Passed to ``yarn_get_mscale`` for the numerator.
            mscale_all_dim: Passed to ``yarn_get_mscale`` for the
                denominator.
        """
        # These attributes are set before the base initializer runs;
        # presumably that initializer builds the cos/sin cache, which reads
        # them -- confirm against RotaryEmbedding.__init__.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        numerator = yarn_get_mscale(self.scaling_factor, float(mscale))
        denominator = yarn_get_mscale(self.scaling_factor,
                                      float(mscale_all_dim))
        self.mscale = float(numerator / denominator * attn_factor)
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Return inverse frequencies blended between the two regimes.

        Each entry is a weighted mix of the interpolated frequency
        (``1 / (scaling_factor * base**e)``) and the extrapolated frequency
        (``1 / base**e``); the weights come from a linear ramp mask.
        """
        exponents = torch.arange(0,
                                 self.rotary_dim,
                                 2,
                                 dtype=torch.float,
                                 device=current_platform.device_type)
        pos_freqs = self.base**(exponents / self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation
        ramp = _yarn_linear_ramp_mask(low,
                                      high,
                                      self.rotary_dim // 2,
                                      dtype=torch.float)
        inv_freq_mask = (1 - ramp) * self.extrapolation_factor
        return (inv_freq_interpolation * (1 - inv_freq_mask) +
                inv_freq_extrapolation * inv_freq_mask)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin lookup table scaled by ``self.mscale``.

        Each row holds the cosines followed by the sines for one position.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        # One row per position of the scaled context length.
        num_positions = self.max_position_embeddings * self.scaling_factor
        t = torch.arange(num_positions,
                         device=current_platform.device_type,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        scaled_cos = freqs.cos() * self.mscale
        scaled_sin = freqs.sin() * self.mscale
        return torch.cat((scaled_cos, scaled_sin), dim=-1)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the cached rotary embedding to ``query`` and ``key``.

        Args:
            positions: Indices used to look up rows of ``cos_sin_cache``.
            query: Tensor whose last dim holds the head features.
            key: Same layout as ``query``; must not be ``None``.
            offsets: Optional tensor added to ``positions`` before lookup.

        Returns:
            ``(query, key)`` where the first ``rotary_dim`` features of the
            last dim are rotated and any remaining features are unchanged.
        """
        assert key is not None
        has_pass = self.rotary_dim < self.head_size
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if has_pass:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        # Keep the cache on the same device as the lookup indices.
        if self.cos_sin_cache.device != positions.device:
            self.cos_sin_cache = self.cos_sin_cache.to(positions.device)

        index = positions if offsets is None else torch.add(
            positions, offsets)
        cos, sin = self.cos_sin_cache[index].chunk(2, dim=-1)
        if self.is_neox_style:
            rotate_fn = _rotate_neox
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2)
            sin = sin.repeat(1, 1, 2)
        else:
            rotate_fn = _rotate_gptj
            cos = cos.repeat_interleave(2, dim=-1)
            sin = sin.repeat_interleave(2, dim=-1)
        # Insert a singleton dim so cos/sin broadcast over the heads dim.
        cos = cos.unsqueeze(-2)
        sin = sin.unsqueeze(-2)

        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if not has_pass:
            return query_rot, key_rot
        return (torch.cat((query_rot, query_pass), dim=-1),
                torch.cat((key_rot, key_pass), dim=-1))

attn_factor instance-attribute

attn_factor = attn_factor

beta_fast instance-attribute

beta_fast = beta_fast

beta_slow instance-attribute

beta_slow = beta_slow

extrapolation_factor instance-attribute

extrapolation_factor = extrapolation_factor

mscale instance-attribute

mscale = float(
    yarn_get_mscale(scaling_factor, float(mscale))
    / yarn_get_mscale(scaling_factor, float(mscale_all_dim))
    * attn_factor
)

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: dtype,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
    mscale: float = 1,
    mscale_all_dim: float = 0,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: torch.dtype,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
    mscale: float = 1,
    mscale_all_dim: float = 0,
) -> None:
    """Store the scaling settings, then run the base initializer.

    ``self.mscale`` is the ratio of two ``yarn_get_mscale`` results (one for
    ``mscale``, one for ``mscale_all_dim``) multiplied by ``attn_factor``.
    """
    # Set before the base initializer runs; presumably that initializer
    # builds the cos/sin cache, which reads these -- confirm against the
    # base class.
    self.scaling_factor = scaling_factor
    self.extrapolation_factor = extrapolation_factor
    self.attn_factor = attn_factor
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow
    # Get n-d magnitude scaling corrected for interpolation.
    numerator = yarn_get_mscale(self.scaling_factor, float(mscale))
    denominator = yarn_get_mscale(self.scaling_factor,
                                  float(mscale_all_dim))
    self.mscale = float(numerator / denominator * attn_factor)
    super().__init__(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
    )

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the cos/sin lookup table scaled by ``self.mscale``.

    Each row holds the cosines followed by the sines of one position times
    every inverse frequency, all multiplied by ``self.mscale``.
    """
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    # One row per position of the scaled context length.
    num_positions = self.max_position_embeddings * self.scaling_factor
    positions = torch.arange(num_positions,
                             device=current_platform.device_type,
                             dtype=torch.float32)
    angles = torch.einsum("i,j -> ij", positions, inv_freq)
    scaled_cos = angles.cos() * self.mscale
    scaled_sin = angles.sin() * self.mscale
    return torch.cat((scaled_cos, scaled_sin), dim=-1)

_compute_inv_freq

_compute_inv_freq(scaling_factor: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Return inverse frequencies blended between the two regimes.

    Each entry is a weighted mix of the interpolated frequency
    (``1 / (scaling_factor * base**e)``) and the extrapolated frequency
    (``1 / base**e``). The weights come from the linear ramp mask over the
    range returned by ``_yarn_find_correction_range``, multiplied by
    ``self.extrapolation_factor``.
    """
    exponents = torch.arange(0,
                             self.rotary_dim,
                             2,
                             dtype=torch.float,
                             device=current_platform.device_type)
    pos_freqs = self.base**(exponents / self.rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(
        self.beta_fast,
        self.beta_slow,
        self.rotary_dim,
        self.base,
        self.max_position_embeddings,
    )
    # Get n-d rotational scaling corrected for extrapolation
    ramp = _yarn_linear_ramp_mask(low,
                                  high,
                                  self.rotary_dim // 2,
                                  dtype=torch.float)
    inv_freq_mask = (1 - ramp) * self.extrapolation_factor
    return (inv_freq_interpolation * (1 - inv_freq_mask) +
            inv_freq_extrapolation * inv_freq_mask)

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]

PyTorch-native implementation of the rotary embedding forward pass.

Source code in vllm/model_executor/layers/rotary_embedding.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply the cached rotary embedding to ``query`` and ``key``.

    Args:
        positions: Indices used to look up rows of ``cos_sin_cache``.
        query: Tensor whose last dim holds the head features.
        key: Same layout as ``query``; must not be ``None``.
        offsets: Optional tensor added to ``positions`` before lookup.

    Returns:
        ``(query, key)`` where the first ``rotary_dim`` features of the last
        dim are rotated and any remaining features are unchanged.
    """
    assert key is not None
    has_pass = self.rotary_dim < self.head_size
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if has_pass:
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]

    # Keep the cache on the same device as the lookup indices.
    if self.cos_sin_cache.device != positions.device:
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)

    index = positions if offsets is None else torch.add(positions, offsets)
    cos, sin = self.cos_sin_cache[index].chunk(2, dim=-1)
    if self.is_neox_style:
        rotate_fn = _rotate_neox
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2)
        sin = sin.repeat(1, 1, 2)
    else:
        rotate_fn = _rotate_gptj
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
    # Insert a singleton dim so cos/sin broadcast over the heads dim.
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)

    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin

    if not has_pass:
        return query_rot, key_rot
    return (torch.cat((query_rot, query_pass), dim=-1),
            torch.cat((key_rot, key_pass), dim=-1))

DualChunkRotaryEmbedding

Bases: CustomOp

Rotary positional embedding for Dual Chunk Attention.

Source code in vllm/model_executor/layers/rotary_embedding.py
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention.

    Five cos/sin lookup tables are precomputed at construction time and
    registered as non-persistent buffers: one for keys and four for the
    query variants returned (concatenated) by ``forward``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        """Store the configuration and build the cos/sin caches.

        Args:
            head_size: Size of one attention head; used to split the last
                dim of query/key into heads.
            rotary_dim: Number of leading head features that are rotated.
            max_position_embeddings: Number of rows in the key cache.
            base: Base of the inverse-frequency exponent.
            is_neox_style: Selects the Neox or GPT-J rotation layout.
            dtype: Dtype the caches are cast to.
            chunk_size: Chunk size; ``chunk_size - local_size`` is the
                period used when indexing the query caches.
            local_size: Subtracted from ``chunk_size`` to get that period.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        # NOTE(review): the target device is hard-coded to the current CUDA
        # device, so constructing this layer requires CUDA.
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache,
         q_inter_cache) = self._compute_cos_sin_cache()

        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer("cos_sin_qc_no_clamp_cache",
                             qc_no_clamp_cache,
                             persistent=False)
        self.register_buffer("cos_sin_q_inter_cache",
                             q_inter_cache,
                             persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE: no device is passed to `torch.arange`, so the tensor is
        # created on the default device; `_compute_cos_sin_cache` moves the
        # finished caches to `self.device`.
        exponents = torch.arange(0, self.rotary_dim, 2,
                                 dtype=torch.float) / self.rotary_dim
        return 1.0 / (base**exponents)

    def _compute_cos_sin_cache(
        self
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Compute the cos and sin caches.

        Returns:
            ``(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
            q_inter_cache)``. Each cache stores cosines followed by sines
            along its last dim, cast to ``self.dtype`` on ``self.device``.
        """
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size

        def build_cache(t: torch.Tensor) -> torch.Tensor:
            # Shared recipe: angles -> [cos | sin] -> target dtype/device.
            freqs = torch.outer(t, inv_freq)
            return torch.cat((freqs.cos(), freqs.sin()),
                             dim=-1).to(dtype=self.dtype, device=self.device)

        q_t = torch.arange(chunk_len, dtype=torch.float)
        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = q_t + chunk_len
        qc_t = qc_no_clamp_t.clamp(max=self.chunk_size)
        k_t = torch.arange(self.max_position_embeddings,
                           dtype=torch.float) % chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = q_t + self.chunk_size

        return (
            build_cache(q_t),
            build_cache(qc_t),
            build_cache(k_t),
            build_cache(qc_no_clamp_t),
            build_cache(q_inter_t),
        )

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``key`` once and ``query`` with five different caches.

        Args:
            positions: Indices used to look up rows of the caches.
            query: Tensor whose last dim is split into heads of
                ``head_size``.
            key: Same layout as ``query``.
            offsets: Optional tensor added to ``positions`` before lookup.

        Returns:
            ``(query, key)`` where ``query`` is the concatenation along the
            last dim of the five rotated query variants.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]
        else:
            query_pass = None
            key_pass = None

        positions_with_offsets = (torch.add(positions, offsets)
                                  if offsets is not None else positions)
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
        chunk_len = self.chunk_size - self.local_size
        # The same wrapped index is used for four query caches; compute it
        # once.
        chunk_positions = positions_with_offsets % chunk_len
        query = self._apply_rotary_embedding(
            self.cos_sin_q_cache[chunk_positions], query_rot, query_pass)
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_positions], query_rot, query_pass)
        # Every token uses the last row of the clamped cache here.
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot, query_pass)
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[chunk_positions], query_rot,
            query_pass)
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[chunk_positions], query_rot,
            query_pass)

        # merge query into one tensor to simplify the interfaces
        query = torch.cat((
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ),
                          dim=-1)
        return query, key

    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
        """Rotate ``hidden_rot`` with ``cos_sin`` and re-attach the rest.

        ``cos_sin`` is split in half along its last dim into cos and sin.
        ``hidden_pass`` is only used when ``rotary_dim < head_size``. The
        last two dims of the result are flattened and ``squeeze(0)`` is
        applied.
        """
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        return hidden.flatten(-2).squeeze(0)

    def extra_repr(self) -> str:
        """Return the configuration as a comma-separated ``key=value`` list."""
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s

base instance-attribute

base = base

chunk_size instance-attribute

chunk_size = chunk_size

device instance-attribute

device = device(f'cuda:{current_device()}')

dtype instance-attribute

dtype = dtype

head_size instance-attribute

head_size = head_size

is_neox_style instance-attribute

is_neox_style = is_neox_style

local_size instance-attribute

local_size = local_size

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

rotary_dim instance-attribute

rotary_dim = rotary_dim

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
    chunk_size: int,
    local_size: int,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    chunk_size: int,
    local_size: int,
) -> None:
    """Store the configuration and register the five cos/sin caches.

    The caches come from ``self._compute_cos_sin_cache()`` and are
    registered as non-persistent buffers.
    """
    super().__init__()
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.chunk_size = chunk_size
    self.local_size = local_size
    self.dtype = dtype
    # NOTE(review): the target device is hard-coded to the current CUDA
    # device.
    self.device = torch.device(f"cuda:{torch.cuda.current_device()}")

    (q_cache, qc_cache, k_cache, qc_no_clamp_cache,
     q_inter_cache) = self._compute_cos_sin_cache()
    named_caches = (
        ("cos_sin_q_cache", q_cache),
        ("cos_sin_qc_cache", qc_cache),
        ("cos_sin_k_cache", k_cache),
        ("cos_sin_qc_no_clamp_cache", qc_no_clamp_cache),
        ("cos_sin_q_inter_cache", q_inter_cache),
    )
    for buffer_name, cache in named_caches:
        self.register_buffer(buffer_name, cache, persistent=False)

_apply_rotary_embedding

_apply_rotary_embedding(cos_sin, hidden_rot, hidden_pass)
Source code in vllm/model_executor/layers/rotary_embedding.py
def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
    """Rotate ``hidden_rot`` with ``cos_sin`` and re-attach ``hidden_pass``.

    ``cos_sin`` is split in half along its last dim into cos and sin.
    ``hidden_pass`` is only used when ``self.rotary_dim < self.head_size``.
    The last two dims of the result are flattened and ``squeeze(0)`` is
    applied.
    """
    cos, sin = cos_sin.chunk(2, dim=-1)
    if self.is_neox_style:
        rotate_fn = _rotate_neox
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2)
        sin = sin.repeat(1, 1, 2)
    else:
        rotate_fn = _rotate_gptj
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
    # Insert a singleton dim so cos/sin broadcast over the heads dim.
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)

    rotated = hidden_rot * cos + rotate_fn(hidden_rot) * sin
    if self.rotary_dim < self.head_size:
        rotated = torch.cat((rotated, hidden_pass), dim=-1)
    return rotated.flatten(-2).squeeze(0)

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor

Compute the cos and sin cache.

Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(
    self
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor]:
    """Compute the cos and sin caches.

    Returns:
        ``(q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache)``.
        Each cache stores cosines followed by sines along its last dim, cast
        to ``self.dtype`` on ``self.device``.
    """
    inv_freq = self._compute_inv_freq(self.base)
    chunk_len = self.chunk_size - self.local_size

    def build_cache(t: torch.Tensor) -> torch.Tensor:
        # Shared recipe: angles -> [cos | sin] -> target dtype/device.
        freqs = torch.outer(t, inv_freq)
        return torch.cat((freqs.cos(), freqs.sin()),
                         dim=-1).to(dtype=self.dtype, device=self.device)

    q_t = torch.arange(chunk_len, dtype=torch.float)
    # count from chunk_len, no clamp(self.chunk_size) restriction
    qc_no_clamp_t = q_t + chunk_len
    qc_t = qc_no_clamp_t.clamp(max=self.chunk_size)
    k_t = torch.arange(self.max_position_embeddings,
                       dtype=torch.float) % chunk_len
    # count from self.chunk_size for q_inter's rope
    q_inter_t = q_t + self.chunk_size

    return (
        build_cache(q_t),
        build_cache(qc_t),
        build_cache(k_t),
        build_cache(qc_no_clamp_t),
        build_cache(q_inter_t),
    )

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor

Compute the inverse frequency.

Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Return ``1 / base**(i / rotary_dim)`` for even ``i`` below rotary_dim."""
    # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
    # However, we use `torch.arange(..., dtype=torch.float)` instead to
    # avoid numerical issues with large base values (e.g., 10000000).
    # This may cause a slight numerical difference between the HF
    # implementation and ours.
    # NOTE(woosuk): To exactly match the HF implementation, we need to
    # use CPU to compute the cache and then move it to GPU. However, we
    # create the cache on GPU for faster initialization. This may cause
    # a slight numerical difference between the HF implementation and ours.
    # NOTE(review): no device is passed to `torch.arange` here, so where the
    # tensor is created depends on the default device -- confirm the note
    # above still applies to this class.
    even_dims = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
    exponents = even_dims / self.rotary_dim
    return 1.0 / (base**exponents)

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/rotary_embedding.py
def extra_repr(self) -> str:
    """Return the layer configuration as a ``key=value`` list."""
    return (f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
            f", max_position_embeddings={self.max_position_embeddings}"
            f", base={self.base}, is_neox_style={self.is_neox_style}"
            f", chunk_size={self.chunk_size}, local_size={self.local_size}")

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Tensor,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``key`` once and ``query`` with five different caches.

    Args:
        positions: Indices used to look up rows of the caches.
        query: Tensor whose last dim is split into heads of ``head_size``.
        key: Same layout as ``query``.
        offsets: Optional tensor added to ``positions`` before lookup.

    Returns:
        ``(query, key)`` where ``query`` is the concatenation along the last
        dim of the five rotated query variants.
    """
    query = query.view(*query.shape[:-1], -1, self.head_size)
    key = key.view(*key.shape[:-1], -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    query_pass = None
    key_pass = None
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]

    if offsets is None:
        shifted = positions
    else:
        shifted = torch.add(positions, offsets)

    key = self._apply_rotary_embedding(self.cos_sin_k_cache[shifted],
                                       key_rot, key_pass)

    chunk_len = self.chunk_size - self.local_size
    # Position wrapped into a chunk; indexes four of the query caches.
    wrapped = shifted % chunk_len
    # Every token uses the last row of the clamped cache for `query_inter`.
    last_row = self.cos_sin_qc_cache[chunk_len - 1].repeat(
        positions.shape[0], 1)

    # Order matters: query, query_succ, query_inter, query_succ_critical,
    # query_inter_critical.
    rotated_queries = (
        self._apply_rotary_embedding(self.cos_sin_q_cache[wrapped],
                                     query_rot, query_pass),
        self._apply_rotary_embedding(self.cos_sin_qc_cache[wrapped],
                                     query_rot, query_pass),
        self._apply_rotary_embedding(last_row, query_rot, query_pass),
        self._apply_rotary_embedding(self.cos_sin_qc_no_clamp_cache[wrapped],
                                     query_rot, query_pass),
        self._apply_rotary_embedding(self.cos_sin_q_inter_cache[wrapped],
                                     query_rot, query_pass),
    )

    # merge query into one tensor to simplify the interfaces
    return torch.cat(rotated_queries, dim=-1), key

DynamicNTKAlphaRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with Dynamic NTK alpha.

Based on the original RotaryEmbedding implementation.

Source code in vllm/model_executor/layers/rotary_embedding.py
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding whose base is rescaled by a Dynamic NTK alpha.

    Based on the original RotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        """Store ``scaling_alpha`` and defer the rest to the base class."""
        # Set before the base initializer runs; presumably that initializer
        # builds the cos/sin cache, which reads this value -- confirm
        # against RotaryEmbedding.__init__.
        self.scaling_alpha = scaling_alpha
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin table using the alpha-scaled base.

        Each row holds the cosines followed by the sines for one position.
        """
        # For Hunyuan DynamicNTKAlphaRotaryEmbedding
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        scaled_base = self.base * self.scaling_alpha**exponent
        inv_freq = self._compute_inv_freq(scaled_base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

scaling_alpha instance-attribute

scaling_alpha = scaling_alpha

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_alpha: float,
    dtype: dtype,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_alpha: float,
    dtype: torch.dtype,
) -> None:
    """Store ``scaling_alpha`` and defer the rest to the base initializer."""
    # Set before the base initializer runs; presumably that initializer
    # builds the cos/sin cache, which reads this value -- confirm against
    # the base class.
    self.scaling_alpha = scaling_alpha
    super().__init__(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
    )

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the cos/sin table using the alpha-scaled base.

    Each row holds the cosines followed by the sines for one position.
    """
    # For Hunyuan DynamicNTKAlphaRotaryEmbedding
    exponent = self.rotary_dim / (self.rotary_dim - 2)
    scaled_base = self.base * self.scaling_alpha**exponent
    inv_freq = self._compute_inv_freq(scaled_base)
    t = torch.arange(self.max_position_embeddings, dtype=torch.float)

    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

DynamicNTKScalingRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with Dynamic NTK scaling.

Credits to the Reddit users /u/bloc97 and /u/emozilla

Source code in vllm/model_executor/layers/rotary_embedding.py
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding whose base is rescaled by Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        """Store ``scaling_factor`` and defer the rest to the base class."""
        # Set before the base initializer runs; presumably that initializer
        # builds the cos/sin cache, which reads this value -- confirm
        # against RotaryEmbedding.__init__.
        self.scaling_factor = scaling_factor
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin table for the scaled maximum length.

        Each row holds the cosines followed by the sines for one position.
        """
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        length_ratio = (self.scaling_factor * max_len /
                        self.max_position_embeddings)
        growth = length_ratio - (self.scaling_factor - 1)
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        scaled_base = self.base * growth**exponent
        inv_freq = self._compute_inv_freq(scaled_base)
        t = torch.arange(max_len, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: dtype,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: torch.dtype,
) -> None:
    """Store the NTK scaling factor, then run the parent constructor."""
    # Assigned before the parent constructor runs; presumably the parent
    # builds the cache during construction -- confirm against the base class.
    self.scaling_factor = scaling_factor
    super().__init__(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
    )

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Return the ``[cos | sin]`` table for the NTK-scaled context.

    Rows are positions, columns are the cosines followed by the sines of
    ``position * inv_freq``.
    """
    # self.max_position_embeddings is the pre-scaling maximum length, so the
    # scaled context is that length times the scaling factor.
    scaled_len = self.max_position_embeddings * self.scaling_factor
    length_ratio = (self.scaling_factor * scaled_len /
                    self.max_position_embeddings)
    exponent = self.rotary_dim / (self.rotary_dim - 2)
    ntk_base = self.base * (length_ratio -
                            (self.scaling_factor - 1))**exponent

    inv_freq = self._compute_inv_freq(ntk_base)
    positions = torch.arange(scaled_len, dtype=torch.float)
    angles = torch.einsum("i,j -> ij", positions, inv_freq)
    return torch.cat((angles.cos(), angles.sin()), dim=-1)

LinearScalingRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with linear scaling.

It supports multiple scaling factors. Since multiple LoRA adapters may have different scaling factors, we need multiple cos/sin caches. In this way, instead of running rotary embedding kernel per lora, we can run multiple lora in a batched way.

In addition to that, we also keep the cos/sin cache for the scaling factor of 1 (default) at all times.

For example, given three scaling factors x=1, y and z with embeddings [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

we construct the cos/sin cache as follows: [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], ... [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

We then use offsets to index into the cos/sin cache for the respective scaling factors.

The offset to cache can be accessed via scaling_factor_to_offset API.

Credits to the Reddit user /u/kaiokendev

Source code in vllm/model_executor/layers/rotary_embedding.py
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may have
    different scaling factors, we need multiple cos/sin caches. In this way,
    instead of running rotary embedding kernel per lora, we can run multiple
    lora in a batched way.

    In addition to that, we also keep the cos/sin cache for the scaling factor
    of 1 (default) at all times.

    For example, given three scaling factors x=1, y and z with embeddings
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

    we construct the cos/sin cache as follows:
    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
        ...
     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset to cache can be accessed via `scaling_factor_to_offset` API.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factors: Union[list[float], float],
        dtype: torch.dtype,
    ) -> None:
        """Normalize ``scaling_factors`` to a list and build the caches.

        Args:
            scaling_factors: a single factor or a list of factors. A bare
                number (int or float) is wrapped into a one-element list.
        """
        # A bare int (e.g. a factor of 2) must be wrapped just like a float;
        # otherwise iterating over it in _compute_cos_sin_cache raises
        # TypeError.
        if isinstance(scaling_factors, (int, float)):
            scaling_factors = [float(scaling_factors)]
        self.scaling_factors: list[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build one ``[cos | sin]`` table per factor, stacked along dim 0.

        Also records, for every scaling factor, the row at which its table
        starts in the stacked tensor (``self._scaling_factor_to_offset``).
        """
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: list[torch.Tensor] = []
        # offsets[i] is the first row of the table for scaling_factors[i].
        offsets: list[int] = []
        next_offset = 0
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            # Linear scaling: positions are compressed by the factor.
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
            offsets.append(next_offset)
            cache_list.append(cache)
            next_offset += cache.shape[0]
        assert len(self.scaling_factors) == len(offsets)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offset
            for scaling_factor, offset in zip(self.scaling_factors, offsets)
        }
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> dict[float, int]:
        """Map each scaling factor to the first row of its cache table."""
        return self._scaling_factor_to_offset

_scaling_factor_to_offset instance-attribute

_scaling_factor_to_offset: dict[float, int]

scaling_factor_to_offset property

scaling_factor_to_offset: dict[float, int]

scaling_factors instance-attribute

scaling_factors: list[float] = scaling_factors

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factors: Union[list[float], float],
    dtype: dtype,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factors: Union[list[float], float],
    dtype: torch.dtype,
) -> None:
    """Normalize ``scaling_factors`` to a list and run the parent constructor.

    Args:
        scaling_factors: a single factor or a list of factors. A bare
            number (int or float) is wrapped into a one-element list.
    """
    # A bare int (e.g. a factor of 2) must be wrapped just like a float;
    # otherwise iterating over it in _compute_cos_sin_cache raises TypeError.
    if isinstance(scaling_factors, (int, float)):
        scaling_factors = [float(scaling_factors)]
    self.scaling_factors: list[float] = scaling_factors  # noqa
    super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                     is_neox_style, dtype)
    # Lazy initialized.
    self._scaling_factor_to_offset: dict[float, int]

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build one ``[cos | sin]`` table per scaling factor, stacked on dim 0.

    As a side effect, stores in ``self._scaling_factor_to_offset`` the row
    at which each factor's table begins inside the stacked tensor.
    """
    inv_freq = self._compute_inv_freq(self.base)
    tables: list[torch.Tensor] = []
    # start_rows[i] is the first row of the table for scaling_factors[i].
    start_rows: list[int] = []
    rows_so_far = 0
    for factor in self.scaling_factors:
        # self.max_position_embeddings is the pre-scaling maximum length, so
        # the scaled maximum length is that value times the factor.
        scaled_len = self.max_position_embeddings * factor
        positions = torch.arange(scaled_len, dtype=torch.float) / factor

        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        table = torch.cat((angles.cos(), angles.sin()), dim=-1)

        start_rows.append(rows_so_far)
        tables.append(table)
        rows_so_far += table.shape[0]

    offset_by_factor = {}
    for factor, start_row in zip(self.scaling_factors, start_rows):
        offset_by_factor[float(factor)] = start_row
    self._scaling_factor_to_offset = offset_by_factor
    assert len(self.scaling_factors) == len(start_rows)
    return torch.cat(tables, dim=0)

Llama3RotaryEmbedding

Bases: RotaryEmbedding

Source code in vllm/model_executor/layers/rotary_embedding.py
class Llama3RotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding whose inverse frequencies are rescaled per wavelength.

    Short wavelengths keep their frequency, long wavelengths are divided by
    ``scaling_factor``, and the band in between is linearly blended (see
    ``_compute_inv_freq``).
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        """Record the scaling parameters, then run the parent constructor."""
        # All four values are read by _compute_inv_freq, so they are set
        # before the parent constructor runs.
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the parent's inverse frequencies after wavelength scaling."""
        unscaled = super()._compute_inv_freq(base)
        # Wavelength thresholds derived from the original context length.
        long_cutoff = self.orig_max_position / self.low_freq_factor
        short_cutoff = self.orig_max_position / self.high_freq_factor

        wavelengths = 2 * math.pi / unscaled
        if self.low_freq_factor == self.high_freq_factor:
            smooth = 0
        else:
            smooth = (self.orig_max_position / wavelengths -
                      self.low_freq_factor) / (self.high_freq_factor -
                                               self.low_freq_factor)

        fully_scaled = unscaled / self.scaling_factor
        blended = ((1 - smooth) * unscaled / self.scaling_factor +
                   smooth * unscaled)
        # Beyond the long cutoff: fully scaled; otherwise the blend.
        scaled_or_blended = torch.where(wavelengths > long_cutoff,
                                        fully_scaled, blended)
        # Below the short cutoff the original frequency is kept untouched.
        return torch.where(wavelengths < short_cutoff, unscaled,
                           scaled_or_blended)

high_freq_factor instance-attribute

high_freq_factor = high_freq_factor

low_freq_factor instance-attribute

low_freq_factor = low_freq_factor

orig_max_position instance-attribute

orig_max_position = orig_max_position

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
    scaling_factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    orig_max_position: int,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    scaling_factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    orig_max_position: int,
) -> None:
    """Record the Llama-3 scaling parameters, then run the parent constructor."""
    # The scaling attributes are set before the parent constructor runs.
    self.scaling_factor = scaling_factor
    self.low_freq_factor = low_freq_factor
    self.high_freq_factor = high_freq_factor
    self.orig_max_position = orig_max_position
    super().__init__(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
    )

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Return the parent's inverse frequencies after wavelength scaling.

    Wavelengths shorter than ``orig_max_position / high_freq_factor`` keep
    their frequency, wavelengths longer than
    ``orig_max_position / low_freq_factor`` are divided by
    ``scaling_factor``, and the band in between is linearly blended.
    """
    unscaled = super()._compute_inv_freq(base)
    long_cutoff = self.orig_max_position / self.low_freq_factor
    short_cutoff = self.orig_max_position / self.high_freq_factor

    wavelengths = 2 * math.pi / unscaled
    if self.low_freq_factor == self.high_freq_factor:
        smooth = 0
    else:
        smooth = (self.orig_max_position / wavelengths -
                  self.low_freq_factor) / (self.high_freq_factor -
                                           self.low_freq_factor)

    fully_scaled = unscaled / self.scaling_factor
    blended = ((1 - smooth) * unscaled / self.scaling_factor +
               smooth * unscaled)
    scaled_or_blended = torch.where(wavelengths > long_cutoff, fully_scaled,
                                    blended)
    return torch.where(wavelengths < short_cutoff, unscaled,
                       scaled_or_blended)

Llama4VisionRotaryEmbedding

Bases: RotaryEmbedding

Source code in vllm/model_executor/layers/rotary_embedding.py
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding over image patches, stored as a complex-valued cache.

    The cache has one row per patch plus one extra row (flagged with a
    negative id) whose rotation angle is zero.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ):
        """Forward every argument unchanged to the parent constructor."""
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the first ``rotary_dim // 2`` parent inverse frequencies."""
        half_dim = self.rotary_dim // 2
        return super()._compute_inv_freq(base)[:half_dim]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the complex ``cos + i*sin`` table over the patch grid."""
        inv_freq = self._compute_inv_freq(self.base)

        # self.max_position_embeddings here is number of image patches
        # i.e. (image_size // patch_size) ** 2
        patch_count = self.max_position_embeddings
        patch_ids = torch.arange(patch_count, dtype=torch.int32)
        patch_ids = patch_ids.reshape(patch_count, 1)
        # Append one extra row and flag it with a negative id.
        patch_ids = torch.cat([patch_ids, patch_ids[:1]], dim=0)
        patch_ids[-1, -1] = -2  # set to ID_CLS_TOKEN
        grid_side = int(math.sqrt(patch_count))

        def axis_angles(coords: torch.Tensor) -> torch.Tensor:
            # 1-based coordinate times each inverse frequency, every value
            # duplicated along the last dim.
            scaled = (coords + 1)[..., None] * inv_freq[None, None, :]
            return scaled.repeat_interleave(2, dim=-1)

        angles_x = axis_angles(patch_ids % grid_side)
        angles_y = axis_angles(patch_ids // grid_side)
        angles = torch.cat([angles_x, angles_y], dim=-1)
        angles = angles.float().contiguous()[..., ::2]
        # Rows with a negative id get a zero angle.
        angles = angles.masked_fill(patch_ids.reshape(-1, 1, 1) < 0, 0)
        cos_sin_pairs = torch.stack(
            [torch.cos(angles), torch.sin(angles)], dim=-1)
        return torch.view_as_complex(cos_sin_pairs)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate ``query`` and ``key`` by the cached complex factors.

        ``key`` is required despite its ``Optional`` annotation. Results are
        cast back to the dtype of the corresponding input.
        """
        assert key is not None
        # Keep the cache on the same device as the inputs.
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)

        # Pair up the last dimension and view each pair as a complex number.
        q_complex = torch.view_as_complex(
            query.float().reshape(*query.shape[:-1], -1, 2))
        k_complex = torch.view_as_complex(
            key.float().reshape(*key.shape[:-1], -1, 2))

        # Keep only dim 1 and the last dim; every other dim broadcasts.
        last_dim = q_complex.ndim - 1
        broadcast_shape = [
            size if index in (1, last_dim) else 1
            for index, size in enumerate(q_complex.shape)
        ]
        rotation = self.cos_sin_cache.view(*broadcast_shape)

        q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
        k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
        return q_rotated.type_as(query), k_rotated.type_as(key)

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
)
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
):
    """Forward every argument unchanged to the parent constructor."""
    super().__init__(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
    )

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the complex ``cos + i*sin`` table over the patch grid.

    The result has one row per patch plus one extra row (flagged with a
    negative id) whose angle is zero.
    """
    inv_freq = self._compute_inv_freq(self.base)

    # self.max_position_embeddings here is number of image patches
    # i.e. (image_size // patch_size) ** 2
    patch_count = self.max_position_embeddings
    patch_ids = torch.arange(patch_count, dtype=torch.int32)
    patch_ids = patch_ids.reshape(patch_count, 1)
    # Append one extra row and flag it with a negative id.
    patch_ids = torch.cat([patch_ids, patch_ids[:1]], dim=0)
    patch_ids[-1, -1] = -2  # set to ID_CLS_TOKEN
    grid_side = int(math.sqrt(patch_count))

    def axis_angles(coords: torch.Tensor) -> torch.Tensor:
        # 1-based coordinate times each inverse frequency, every value
        # duplicated along the last dim.
        scaled = (coords + 1)[..., None] * inv_freq[None, None, :]
        return scaled.repeat_interleave(2, dim=-1)

    angles_x = axis_angles(patch_ids % grid_side)
    angles_y = axis_angles(patch_ids // grid_side)
    angles = torch.cat([angles_x, angles_y], dim=-1)
    angles = angles.float().contiguous()[..., ::2]
    # Rows with a negative id get a zero angle.
    angles = angles.masked_fill(patch_ids.reshape(-1, 1, 1) < 0, 0)
    cos_sin_pairs = torch.stack([torch.cos(angles), torch.sin(angles)],
                                dim=-1)
    return torch.view_as_complex(cos_sin_pairs)

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Return the first ``rotary_dim // 2`` parent inverse frequencies."""
    half_dim = self.rotary_dim // 2
    return super()._compute_inv_freq(base)[:half_dim]

forward

forward(
    query: Tensor, key: Optional[Tensor] = None
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward(
    self,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Rotate ``query`` and ``key`` by the cached complex factors.

    ``key`` is required despite its ``Optional`` annotation. Results are
    cast back to the dtype of the corresponding input.
    """
    assert key is not None
    # Keep the cache on the same device as the inputs.
    self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)

    # Pair up the last dimension and view each pair as a complex number.
    q_complex = torch.view_as_complex(
        query.float().reshape(*query.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(
        key.float().reshape(*key.shape[:-1], -1, 2))

    # Keep only dim 1 and the last dim; every other dim broadcasts.
    last_dim = q_complex.ndim - 1
    broadcast_shape = [
        size if index in (1, last_dim) else 1
        for index, size in enumerate(q_complex.shape)
    ]
    rotation = self.cos_sin_cache.view(*broadcast_shape)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(query), k_rotated.type_as(key)

MRotaryEmbedding

Bases: RotaryEmbedding

Rotary Embedding with Multimodal Sections.

Source code in vllm/model_executor/layers/rotary_embedding.py
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[list[int]] = None,
    ) -> None:
        """Build a 4x-enlarged cache and record the multimodal sections.

        Args:
            mrope_section: optional section sizes; when given and non-empty
                they must sum to ``rotary_dim // 2``.
        """
        # In Qwen2.5-VL, the maximum index value is related to the duration
        # of the input video, so the cos/sin cache is built for four times
        # max_position_embeddings.
        self.cache_max_position_num = max_position_embeddings * 4
        super().__init__(
            head_size,
            rotary_dim,
            self.cache_max_position_num,
            base,
            is_neox_style,
            dtype,
        )

        self.mrope_section = mrope_section
        if self.mrope_section:
            half_rotary_dim = rotary_dim // 2
            assert sum(self.mrope_section) == half_rotary_dim

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply rotary embeddings (with multimodal sections) to query/key.

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]

        Returns:
            The rotated ``(query, key)`` pair, each reshaped back to the
            shape it was passed in with. ``key`` is required despite its
            ``Optional`` annotation.
        """
        assert positions.ndim in (1, 2)
        assert key is not None

        num_tokens = positions.shape[-1]
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

        if positions.ndim == 2:
            assert self.mrope_section
            section_sizes = self.mrope_section

            def pick_per_section(table: torch.Tensor) -> torch.Tensor:
                # The i-th section of the last dim is taken from row i of
                # the leading dim.
                pieces = table.split(section_sizes, dim=-1)
                return torch.cat(
                    [piece[row] for row, piece in enumerate(pieces)], dim=-1)

            cos = pick_per_section(cos)
            sin = pick_per_section(sin)

        def rotate(states: torch.Tensor) -> torch.Tensor:
            # Rotate the first rotary_dim features of every head and pass
            # the remaining features through unchanged.
            original_shape = states.shape
            per_head = states.view(num_tokens, -1, self.head_size)
            rotated = _apply_rotary_emb(per_head[..., :self.rotary_dim], cos,
                                        sin, self.is_neox_style)
            untouched = per_head[..., self.rotary_dim:]
            merged = torch.cat((rotated, untouched), dim=-1)
            return merged.reshape(original_shape)

        query = rotate(query)
        key = rotate(key)
        return query, key

    @classmethod
    def get_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]],
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[list[list[int]], int]:
        """Get mrope input positions (as nested lists) and the delta value.

        ``None`` grids / timestamps are replaced by empty lists before the
        call is delegated to ``get_input_positions_tensor``; the returned
        position tensor is converted with ``tolist()``.
        """
        # Normalize the optional multimodal inputs.
        if image_grid_thw is None:
            image_grid_thw = []
        if video_grid_thw is None:
            video_grid_thw = []
        if second_per_grid_ts is None:
            second_per_grid_ts = []

        positions, position_delta = cls.get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )
        return positions.tolist(), position_delta

    @classmethod
    def get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Dispatch to the model-specific mrope position builder.

        The omni builder is used when ``thinker_uses_mrope(hf_config)`` is
        true, the GLM4V builder when ``"glm4v"`` occurs in
        ``hf_config.model_type``, and the generic VL builder otherwise.
        """
        from vllm.transformers_utils.config import thinker_uses_mrope

        # Arguments accepted by every builder.
        shared_kwargs = dict(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            context_len=context_len,
            seq_len=seq_len,
        )

        if thinker_uses_mrope(hf_config):
            return cls._omni_get_input_positions_tensor(
                second_per_grid_ts=second_per_grid_ts,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
                **shared_kwargs,
            )

        if "glm4v" in hf_config.model_type:
            return cls._glm4v_get_input_positions_tensor(**shared_kwargs)

        return cls._vl_get_input_positions_tensor(
            second_per_grid_ts=second_per_grid_ts,
            **shared_kwargs,
        )

    @classmethod
    def _glm4v_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value for GLM4V.

        Tokens are classified as "image", "video" or "text", grouped into
        consecutive runs of the same type, and each run contributes a block
        of 3-row (T/H/W) position ids that starts right after the largest
        id emitted so far.

        Args:
            input_tokens: token ids of the prompt.
            hf_config: supplies ``image_token_id``, ``video_start_token_id``,
                ``video_end_token_id`` and
                ``vision_config.spatial_merge_size``.
            image_grid_thw: per-item (t, h, w) grid sizes; a tensor is
                converted with ``tolist()``.
            video_grid_thw: only checked against ``None`` in this method.
            context_len: start of the returned slice along the token axis.
            seq_len: end of the returned slice (``None`` keeps the tail).

        Returns:
            ``(llm_positions, mrope_position_delta)`` where ``llm_positions``
            is the ``[3, context_len:seq_len]`` slice of the position ids
            and the delta is ``max(llm_positions) + 1 - len(input_tokens)``.
        """

        image_token_id = hf_config.image_token_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        llm_pos_ids_list: list = []

        if not (image_grid_thw is None and video_grid_thw is None):
            if isinstance(image_grid_thw, torch.Tensor):
                image_grid_thw = image_grid_thw.tolist()

            # Step 1: label every token. An image token that appears between
            # a video-start and a video-end token counts as "video".
            input_token_type: list[str] = []
            video_check_flg = False
            for token in input_tokens:
                if token == video_start_token_id:
                    video_check_flg = True
                elif token == video_end_token_id:
                    video_check_flg = False

                if (token == image_token_id) and (video_check_flg is False):
                    input_token_type.append("image")
                elif (token == image_token_id) and (video_check_flg is True):
                    input_token_type.append("video")
                else:
                    input_token_type.append("text")

            # Step 2: collapse consecutive equal labels into
            # (label, start_index, end_index) runs; end_index is exclusive.
            input_type_group: list[tuple[str, int, int]] = []
            for key, group_iter in itertools.groupby(
                    enumerate(input_token_type), lambda x: x[1]):
                group_list = list(group_iter)
                start_index = group_list[0][0]
                end_index = group_list[-1][0] + 1
                input_type_group.append((key, start_index, end_index))

            # Step 3: emit position ids run by run.
            # video_frame_num counts consecutive video runs and is reset to 1
            # by every text run; mm_data_idx is shared by image and video
            # runs.
            video_frame_num = 1
            mm_data_idx = 0
            for modality_type, start_idx, end_idx in input_type_group:
                # Each run continues after the largest id emitted so far.
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                    llm_pos_ids_list) > 0 else 0
                if modality_type == "image":
                    t, h, w = (
                        image_grid_thw[mm_data_idx][0],
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    # Height and width shrink by the spatial merge size.
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    # Flattened (t, h, w) coordinates of every grid cell.
                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)
                    mm_data_idx += 1

                elif modality_type == "video":
                    # NOTE(review): the h/w grid size is read from
                    # image_grid_thw even for video runs, and video_grid_thw
                    # is never indexed in this method -- confirm this is
                    # intended.
                    t, h, w = (
                        video_frame_num,
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    # One h*w block per t_idx; all blocks of this run share
                    # the same st_idx.
                    for t_idx in range(llm_grid_t):
                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
                            -1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(
                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(
                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(
                            torch.stack([t_index, h_index, w_index]) + st_idx)

                    mm_data_idx += 1
                    video_frame_num += 1

                else:
                    # Text: the same running index on all three rows.
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) +
                        st_idx)
                    video_frame_num = 1

        else:
            # Both grids are None: plain sequential ids on all three rows.
            text_len = len(input_tokens)
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:seq_len]
        # The delta is computed from the sliced positions, not the full set.
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        return llm_positions, mrope_position_delta

    @classmethod
    def _vl_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value.

        The prompt is processed one vision item at a time. The number of
        items is the number of image/video tokens that directly follow a
        `vision_start_token_id`. For every item, the text in front of it
        receives the same running index on all three rows, and the item's
        own tokens receive (temporal, height, width) indices placed right
        after that text.

        Args:
            input_tokens: Prompt token ids.
            hf_config: Config providing `image_token_id`, `video_token_id`,
                `vision_start_token_id`, `vision_config.spatial_merge_size`
                and, optionally, `vision_config.tokens_per_second`
                (defaults to 1.0).
            image_grid_thw: Per-image (t, h, w) grid sizes.
            video_grid_thw: Per-video (t, h, w) grid sizes.
            second_per_grid_ts: Per-video temporal scale; when empty, 1.0 is
                used for every video.
            context_len: Start of the returned position window.
            seq_len: End of the returned position window (None = to the end).

        Returns:
            `(llm_positions, mrope_position_delta)` where `llm_positions` has
            three rows and is sliced to `[context_len:seq_len]`, and the
            delta is `max position + 1 - len(input_tokens)` computed over the
            whole prompt, i.e. before slicing.
        """

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config,
                                    "tokens_per_second", 1.0)

        # The token right after each vision-start marker tells whether the
        # item is an image or a video.
        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        # Plain Python ints: they drive `range()` and the counters below.
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())
        llm_pos_ids_list: list = []

        num_tokens = len(input_tokens)
        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            # Images keep a temporal multiplier of 0.0, so all of their
            # temporal indices are equal.
            video_second_per_grid_t = 0.0
            # A positive remaining count already implies that the token
            # occurs in `input_tokens` (the count was derived from it), so
            # no extra membership scan is needed per iteration.
            if remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = num_tokens + 1
            if remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = num_tokens + 1

            # Whichever placeholder comes first is the next item.
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            # Text preceding the item: one shared index on all three rows.
            st_idx = (llm_pos_ids_list[-1].max() +
                      1 if llm_pos_ids_list else 0)
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                       tokens_per_second).long().flatten()

            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            # Continue scanning after this item's placeholder tokens.
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision item (or the whole prompt when
        # there is no vision item).
        if st < num_tokens:
            st_idx = (llm_pos_ids_list[-1].max() +
                      1 if llm_pos_ids_list else 0)
            text_len = num_tokens - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        # The delta is taken over the full prompt, before windowing.
        mrope_position_delta = (llm_positions.max() + 1 - num_tokens).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def _omni_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value (Qwen2.5-Omni version).

        Differences from MRotaryEmbedding:
            1. Add audio support (and related `audio_feature_lengths`).
            2. Add `use_audio_in_video` option to read audio from video inputs.
                In this case, audio and vision position ids will be split into
                chunks and interleaved.

        Example:

            (V_i are vision position ids, A_i are audio position ids)

            |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
            |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...

        Args:
            input_tokens: Prompt token ids.
            hf_config: Config whose `thinker_config` provides the token ids,
                `seconds_per_chunk` and the vision config.
            image_grid_thw: Per-image (t, h, w) grid sizes (list or tensor).
            video_grid_thw: Per-video (t, h, w) grid sizes (list or tensor).
            second_per_grid_ts: Per-video temporal scale; when empty or None,
                1 is used for every video.
            context_len: Start of the returned position window.
            seq_len: End of the returned position window (None = to the end).
            audio_feature_lengths: Per-audio feature lengths; required as
                soon as an audio item (or audio-in-video) is encountered.
            use_audio_in_video: Interleave audio chunks with video chunks.

        Returns:
            `(llm_positions, mrope_position_delta)`: a 3-row position tensor
            sliced to `[context_len:seq_len]`, and the Python int
            `max position + 1 - len(input_tokens)` computed before slicing.
        """

        # TODO(fyabc): refactor and share more code with
        #  _vl_get_input_positions_tensor.

        thinker_config = hf_config.thinker_config
        audio_token_id = thinker_config.audio_token_index
        image_token_id = thinker_config.image_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        vision_start_token_id = thinker_config.vision_start_token_id
        vision_end_token_id = thinker_config.vision_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(thinker_config.vision_config,
                                    "tokens_per_second", 25)

        if isinstance(image_grid_thw, list):
            image_grid_thw = torch.tensor(image_grid_thw)
        if isinstance(video_grid_thw, list):
            video_grid_thw = torch.tensor(video_grid_thw)

        src_item = input_tokens
        audio_seqlens = audio_feature_lengths
        if not second_per_grid_ts:
            second_per_grid_ts = [1] * video_grid_thw.shape[0]
        mm_token_ids = (audio_token_id, video_token_id, image_token_id)
        audio_idx = 0
        video_idx = 0
        image_idx = 0
        # `new_src_item` mirrors the tokens covered so far; its growth per
        # iteration is what advances `idx` at the bottom of the loop.
        new_src_item: list[int] = []
        llm_pos_ids_list: list[torch.Tensor] = []

        idx = 0
        while idx < len(src_item):
            new_src_item_len = len(new_src_item)
            start_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if src_item[idx] not in mm_token_ids:
                if use_audio_in_video and idx > 0:
                    if src_item[idx] == vision_end_token_id and \
                        src_item[idx - 1] == audio_end_token_id:
                        # processing the <|audio_eos|> before <|vision_eos|>
                        start_idx -= 1
                    elif src_item[idx] == audio_start_token_id and \
                        src_item[idx - 1] == vision_start_token_id:
                        # processing the <|audio_bos|> after <|vision_bos|>
                        start_idx -= 1
                new_src_item.append(src_item[idx])
                llm_pos_ids = torch.tensor([start_idx],
                                           dtype=torch.long).expand(3, -1)
                llm_pos_ids_list.append(llm_pos_ids)
            elif src_item[idx] == audio_token_id:
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
                new_src_item.extend([audio_token_id] * place_num)
                llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
                llm_pos_ids_list.append(llm_pos_ids)
                audio_idx += 1
            elif src_item[idx] == image_token_id:
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
                    grid_ws)
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = image_grid_thw[image_idx].prod() // (
                    spatial_merge_size**2)
                new_src_item.extend([image_token_id] * vision_seqlen)
                image_idx += 1
            elif src_item[idx] == video_token_id and not use_audio_in_video:
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) *
                           second_per_grid_ts[video_idx] *
                           tokens_per_second).long()
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
                    grid_ws)
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2)
                new_src_item.extend([video_token_id] * vision_seqlen)
                video_idx += 1
            else:
                # read audio from video
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                grid_t = video_grid_thw[video_idx][0]
                grid_h = video_grid_thw[video_idx][1]
                grid_w = video_grid_thw[video_idx][2]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
                t_index = (torch.arange(grid_t) *
                           second_per_grid_ts[video_idx] *
                           tokens_per_second).long()
                t_index_split_chunk = cls._split_list_into_ranges(
                    t_index, t_ntoken_per_chunk)
                # Number of audio placeholder tokens (without bos/eos).
                pure_audio_len = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
                added_audio_len = 0
                audio_llm_pos_ids_list: list[torch.Tensor] = []
                for t_chunk in t_index_split_chunk:
                    # Vision part of the chunk. Every chunk uses the same
                    # `start_idx`; the temporal values in `t_chunk` carry the
                    # per-chunk offset.
                    vision_ntoken_per_chunk = len(
                        t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                    new_src_item.extend([video_token_id] *
                                        vision_ntoken_per_chunk)
                    vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
                        start_idx, video_idx, spatial_merge_size, t_chunk,
                        grid_hs, grid_ws).split(1, dim=1)
                    llm_pos_ids_list.extend(vision_llm_pos_ids_list)

                    # Audio part of the chunk: at most one chunk's worth of
                    # the audio tokens that are still unplaced.
                    audio_chunk_len = min(t_ntoken_per_chunk,
                                          pure_audio_len - added_audio_len)
                    new_src_item.extend(audio_chunk_len * [audio_token_id])
                    # Continue right after the previous audio chunk; restart
                    # at `start_idx` when the previous chunk held no audio.
                    audio_start_idx = start_idx if len(
                        audio_llm_pos_ids_list
                    ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                    if audio_chunk_len > 0:
                        audio_llm_pos_ids_list = (
                            torch.arange(audio_chunk_len).expand(3, -1) +
                            audio_start_idx).split(1, dim=1)
                    else:
                        audio_llm_pos_ids_list = []
                    added_audio_len += audio_chunk_len
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                # Audio that outlasts the video is appended after all chunks.
                if added_audio_len < pure_audio_len:
                    new_src_item.extend(
                        (pure_audio_len - added_audio_len) * [audio_token_id])
                    audio_llm_pos_ids_list = (
                        torch.arange(pure_audio_len - added_audio_len).expand(
                            3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                                1, dim=1)
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                audio_idx += 1
                video_idx += 1
            # move to the next token
            idx += len(new_src_item) - new_src_item_len

        llm_positions = torch.cat(llm_pos_ids_list, dim=1)
        # Computed over the full prompt (before windowing) and returned as a
        # Python int, matching the annotation and the sibling variants.
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(src_item)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @staticmethod
    def _get_llm_pos_ids_for_vision(
        start_idx: int,
        vision_idx: int,
        spatial_merge_size: int,
        t_index: list[int],
        grid_hs: torch.Tensor,
        grid_ws: torch.Tensor,
    ) -> torch.Tensor:
        llm_pos_ids_list = []
        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
        h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
            len(t_index), -1, llm_grid_w).flatten())
        w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
            len(t_index), llm_grid_h, -1).flatten())
        t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
            -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
        _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
        llm_pos_ids_list.append(_llm_pos_ids + start_idx)
        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
        return llm_pos_ids

    @staticmethod
    def _split_list_into_ranges(lst: torch.Tensor,
                                interval: int) -> list[list[int]]:
        ranges: list[list[int]] = [[]
                                   for _ in range((max(lst) // interval) + 1)]
        for num in lst:
            index = num // interval
            ranges[index].append(num)
        return ranges

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> list[list[int]]:
        return [
            list(
                range(context_len + mrope_position_delta,
                      seq_len + mrope_position_delta)) for _ in range(3)
        ]

    @staticmethod
    def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
                                        mrope_position_delta: int,
                                        context_len: int, num_new_tokens: int):

        values = np.arange(mrope_position_delta + context_len,
                           mrope_position_delta + context_len + num_new_tokens,
                           dtype=out.dtype)
        out[:, out_offset:out_offset + num_new_tokens] = values

    @classmethod
    def omni_get_updates_use_audio_in_video(
        cls,
        thinker_config: PretrainedConfig,
        audio_len: int,
        video_grid_thw: Union[list[int], torch.Tensor],
        video_second_per_grid_t: float,
    ) -> list[int]:
        """Build the video prompt update used when `use_audio_in_video` is on.

        The video placeholder is expanded into alternating chunks of video
        tokens and audio tokens (the layout is described in
        `_omni_get_input_positions_tensor`), wrapped in the audio start and
        end tokens:

        <|video_bos|><|VIDEO|><|video_eos|> =>
        <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

        Args:
            thinker_config: Config providing the token ids,
                `seconds_per_chunk` and the vision config.
            audio_len: Number of audio tokens to emit in total.
            video_grid_thw: (t, h, w) grid sizes of the video.
            video_second_per_grid_t: Temporal scale of one grid step.

        Returns:
            The replacement token ids, starting with the audio start token
            and ending with the audio end token.
        """

        audio_token_id = thinker_config.audio_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(thinker_config.vision_config,
                                    "tokens_per_second", 25)

        grid_t, grid_h, grid_w = (video_grid_thw[0], video_grid_thw[1],
                                  video_grid_thw[2])
        chunk_capacity = int(tokens_per_second * seconds_per_chunk)
        merge_area = spatial_merge_size**2
        frame_positions = (torch.arange(grid_t) * video_second_per_grid_t *
                           tokens_per_second).long()

        updates = [audio_start_token_id]
        audio_emitted = 0
        for frames in cls._split_list_into_ranges(frame_positions,
                                                  chunk_capacity):
            # Video tokens of this chunk.
            num_video_tokens = len(frames) * grid_h * grid_w // merge_area
            updates.extend([video_token_id] * num_video_tokens)

            # Audio tokens of this chunk, capped by what is still left.
            num_audio_tokens = min(chunk_capacity, audio_len - audio_emitted)
            updates.extend(num_audio_tokens * [audio_token_id])
            audio_emitted += num_audio_tokens

        # Audio that did not fit into the video chunks goes at the end.
        if audio_emitted < audio_len:
            updates.extend((audio_len - audio_emitted) * [audio_token_id])
        updates.append(audio_end_token_id)

        return updates

cache_max_position_num instance-attribute

cache_max_position_num = max_position_embeddings * 4

mrope_section instance-attribute

mrope_section = mrope_section

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
    mrope_section: Optional[list[int]] = None,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    mrope_section: Optional[list[int]] = None,
) -> None:
    """Initialize the embedding with a 4x enlarged position count.

    The parent initializer receives `max_position_embeddings * 4` (stored
    as `cache_max_position_num`) in place of `max_position_embeddings`.
    When `mrope_section` is given and non-empty, its entries must sum to
    `rotary_dim // 2`.
    """
    # In Qwen2.5-VL, the maximum index value is related to the duration of
    # the input video. We enlarge max_position_embeddings to 4 times to get
    # a larger cos and sin cache.
    self.cache_max_position_num = max_position_embeddings * 4
    super().__init__(head_size, rotary_dim, self.cache_max_position_num,
                     base, is_neox_style, dtype)

    self.mrope_section = mrope_section
    if self.mrope_section:
        assert sum(self.mrope_section) == rotary_dim // 2

_get_llm_pos_ids_for_vision staticmethod

_get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: Tensor,
    grid_ws: Tensor,
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
@staticmethod
def _get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: torch.Tensor,
    grid_ws: torch.Tensor,
) -> torch.Tensor:
    llm_pos_ids_list = []
    llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
    llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
    h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
        len(t_index), -1, llm_grid_w).flatten())
    w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
        len(t_index), llm_grid_h, -1).flatten())
    t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
        -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
    _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
    llm_pos_ids_list.append(_llm_pos_ids + start_idx)
    llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
    return llm_pos_ids

_glm4v_get_input_positions_tensor classmethod

_glm4v_get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[Tensor, int]

Get mrope input positions and delta value for GLM4V.

Source code in vllm/model_executor/layers/rotary_embedding.py
@classmethod
def _glm4v_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V.

    Every token is classified as "text", "image" or "video" (an image token
    between a video start and a video end token counts as "video"), and
    consecutive tokens of the same kind are processed as one group. When
    both grid arguments are None the prompt is treated as pure text.

    Args:
        input_tokens: Prompt token ids.
        hf_config: Config providing `image_token_id`,
            `video_start_token_id`, `video_end_token_id` and
            `vision_config.spatial_merge_size`.
        image_grid_thw: Per-item (t, h, w) grid sizes, or None.
        video_grid_thw: Only checked against None in this function.
        context_len: Start of the returned position window.
        seq_len: End of the returned position window (None = to the end).

    Returns:
        `(llm_positions, mrope_position_delta)`: a 3-row position tensor
        sliced to `[context_len:seq_len]`, and
        `max position + 1 - len(input_tokens)` computed over the whole
        prompt, i.e. before slicing.
    """

    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    llm_pos_ids_list: list = []

    if not (image_grid_thw is None and video_grid_thw is None):
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()

        # Classify every token; image tokens inside a video span are video.
        input_token_type: list[str] = []
        video_check_flg = False
        for token in input_tokens:
            if token == video_start_token_id:
                video_check_flg = True
            elif token == video_end_token_id:
                video_check_flg = False

            if token == image_token_id:
                input_token_type.append(
                    "video" if video_check_flg else "image")
            else:
                input_token_type.append("text")

        # Collapse runs of equal types into (type, start, end) groups.
        input_type_group: list[tuple[str, int, int]] = []
        for key, group_iter in itertools.groupby(
                enumerate(input_token_type), lambda x: x[1]):
            group_list = list(group_iter)
            start_index = group_list[0][0]
            end_index = group_list[-1][0] + 1
            input_type_group.append((key, start_index, end_index))

        video_frame_num = 1
        mm_data_idx = 0
        for modality_type, start_idx, end_idx in input_type_group:
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if modality_type == "image":
                t, h, w = (
                    image_grid_thw[mm_data_idx][0],
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                    -1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                    llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                    llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1

            elif modality_type == "video":
                # NOTE(review): video groups read (h, w) from image_grid_thw
                # at the shared mm_data_idx; video_grid_thw is not indexed
                # anywhere in this function - confirm this is intended.
                t, h, w = (
                    video_frame_num,
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size

                for t_idx in range(llm_grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(
                        1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(
                        1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)

                mm_data_idx += 1
                video_frame_num += 1

            else:
                text_len = end_idx - start_idx
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) +
                    st_idx)
                # A text group resets the per-video frame counter.
                video_frame_num = 1

    else:
        text_len = len(input_tokens)
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1))

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    # Compute the delta from the positions of the whole prompt. Taking it
    # after the [context_len:seq_len] slice would give a wrong value for a
    # truncated window and fail on an empty one.
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    llm_positions = llm_positions[:, context_len:seq_len]
    return llm_positions, mrope_position_delta

_omni_get_input_positions_tensor classmethod

_omni_get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    second_per_grid_ts: Optional[list[float]] = None,
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[Tensor, int]

Get mrope input positions and delta value (Qwen2.5-Omni version).

Differences from MRotaryEmbedding
  1. Add audio support (and related audio_feature_lengths).
  2. Add use_audio_in_video option to read audio from video inputs. In this case, audio and vision position ids will be split into chunks and interleaved.

Example:

(V_i are vision position ids, A_i are audio position ids)

|V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
Source code in vllm/model_executor/layers/rotary_embedding.py
@classmethod
def _omni_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: Optional[list[float]] = None,
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value (Qwen2.5-Omni version).

    Differences from MRotaryEmbedding:
        1. Add audio support (and related `audio_feature_lengths`).
        2. Add `use_audio_in_video` option to read audio from video inputs.
            In this case, audio and vision position ids will be split into
            chunks and interleaved.

    Example:

        (V_i are vision position ids, A_i are audio position ids)

        |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
        |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
    """

    # TODO(fyabc): refactor and share more code with
    #  _vl_get_input_positions_tensor.

    thinker_config = hf_config.thinker_config
    audio_token_id = thinker_config.audio_token_index
    image_token_id = thinker_config.image_token_index
    video_token_id = thinker_config.video_token_index
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    vision_start_token_id = thinker_config.vision_start_token_id
    vision_end_token_id = thinker_config.vision_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    spatial_merge_size = thinker_config.vision_config.spatial_merge_size
    tokens_per_second = getattr(thinker_config.vision_config,
                                "tokens_per_second", 25)

    if isinstance(image_grid_thw, list):
        image_grid_thw = torch.tensor(image_grid_thw)
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.tensor(video_grid_thw)

    src_item = input_tokens
    audio_seqlens = audio_feature_lengths
    if not second_per_grid_ts:
        second_per_grid_ts = [1] * video_grid_thw.shape[0]
    audio_idx = 0
    video_idx = 0
    image_idx = 0
    new_src_item: list[int] = []
    llm_pos_ids_list: list[torch.Tensor] = []

    idx = 0
    while idx < len(src_item):
        new_src_item_len = len(new_src_item)
        start_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        if src_item[idx] not in [
                audio_token_id, video_token_id, image_token_id
        ]:
            if use_audio_in_video and idx > 0:
                if src_item[idx] == vision_end_token_id and \
                    src_item[idx - 1] == audio_end_token_id:
                    # processing the <|audio_eos|> before <|vision_eos|>
                    start_idx -= 1
                elif src_item[idx] == audio_start_token_id and \
                    src_item[idx - 1] == vision_start_token_id:
                    # processing the <|audio_bos|> after <|vision_eos|>
                    start_idx -= 1
            new_src_item.append(src_item[idx])
            llm_pos_ids = torch.tensor([start_idx],
                                       dtype=torch.long).expand(3, -1)
            llm_pos_ids_list.append(llm_pos_ids)
        elif src_item[idx] == audio_token_id:
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
            new_src_item.extend([audio_token_id] * place_num)
            llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
            llm_pos_ids_list.append(llm_pos_ids)
            audio_idx += 1
        elif src_item[idx] == image_token_id:
            grid_t = image_grid_thw[image_idx][0]
            grid_hs = image_grid_thw[:, 1]
            grid_ws = image_grid_thw[:, 2]
            t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
            llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
                grid_ws)
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = image_grid_thw[image_idx].prod() // (
                spatial_merge_size**2)
            new_src_item.extend([image_token_id] * vision_seqlen)
            image_idx += 1
        elif src_item[idx] == video_token_id and not use_audio_in_video:
            grid_t = video_grid_thw[video_idx][0]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_index = (torch.arange(grid_t) *
                       second_per_grid_ts[video_idx] *
                       tokens_per_second).long()
            llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
                grid_ws)
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2)
            new_src_item.extend([video_token_id] * vision_seqlen)
            video_idx += 1
        else:
            # read audio from video
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2)
            grid_t = video_grid_thw[video_idx][0]
            grid_h = video_grid_thw[video_idx][1]
            grid_w = video_grid_thw[video_idx][2]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
            t_index = (torch.arange(grid_t) *
                       second_per_grid_ts[video_idx] *
                       tokens_per_second).long()
            t_index_split_chunk = cls._split_list_into_ranges(
                t_index, t_ntoken_per_chunk)
            place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
            pure_audio_len = place_num - 2
            added_audio_len = 0
            audio_llm_pos_ids_list: list[torch.Tensor] = []
            for t_chunk in t_index_split_chunk:
                vision_ntoken_per_chunk = len(
                    t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                new_src_item.extend([video_token_id] *
                                    vision_ntoken_per_chunk)
                vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_chunk,
                    grid_hs, grid_ws).split(1, dim=1)
                llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                new_src_item.extend(
                    min(t_ntoken_per_chunk, pure_audio_len -
                        added_audio_len) * [audio_token_id])
                audio_start_idx = start_idx if len(
                    audio_llm_pos_ids_list
                ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                if min(t_ntoken_per_chunk,
                       pure_audio_len - added_audio_len) > 0:
                    audio_llm_pos_ids_list = (torch.arange(
                        min(t_ntoken_per_chunk, pure_audio_len -
                            added_audio_len)).expand(3, -1) +
                                              audio_start_idx).split(1,
                                                                     dim=1)
                else:
                    audio_llm_pos_ids_list = []
                added_audio_len += min(t_ntoken_per_chunk,
                                       pure_audio_len - added_audio_len)
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            if added_audio_len < pure_audio_len:
                new_src_item.extend(
                    (pure_audio_len - added_audio_len) * [audio_token_id])
                audio_llm_pos_ids_list = (
                    torch.arange(pure_audio_len - added_audio_len).expand(
                        3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                            1, dim=1)
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            audio_idx += 1
            video_idx += 1
        # move to the next token
        idx += len(new_src_item) - new_src_item_len

    llm_positions = torch.cat(llm_pos_ids_list, dim=1)
    mrope_position_delta = torch.cat(llm_pos_ids_list,
                                     dim=1).max() + 1 - len(src_item)
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta

_split_list_into_ranges staticmethod

_split_list_into_ranges(
    lst: Tensor, interval: int
) -> list[list[int]]
Source code in vllm/model_executor/layers/rotary_embedding.py
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
                            interval: int) -> list[list[int]]:
    ranges: list[list[int]] = [[]
                               for _ in range((max(lst) // interval) + 1)]
    for num in lst:
        index = num // interval
        ranges[index].append(num)
    return ranges

_vl_get_input_positions_tensor classmethod

_vl_get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[Tensor, int]

Get mrope input positions and delta value.

Source code in vllm/model_executor/layers/rotary_embedding.py
@classmethod
def _vl_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value.

    Walks the prompt from left to right. Each text span receives
    consecutive positions replicated on all three rows; each image/video
    span receives (t, h, w) grid indices shifted past the preceding text.

    Args:
        input_tokens: Prompt token ids.
        hf_config: Config providing `image_token_id`, `video_token_id`,
            `vision_start_token_id` and
            `vision_config.spatial_merge_size`; the optional
            `vision_config.tokens_per_second` defaults to 1.0.
        image_grid_thw: Per-image (t, h, w) grid sizes, consumed in order.
        video_grid_thw: Per-video (t, h, w) grid sizes, consumed in order.
        second_per_grid_ts: Per-video temporal scale; when empty, 1.0 is
            used for every video.
        context_len: Start of the slice taken from the computed positions.
        seq_len: End of that slice (`None` keeps everything to the end).

    Returns:
        The position tensor with 3 rows, sliced to
        `[:, context_len:seq_len]`, and the delta
        `max(position) + 1 - len(input_tokens)` computed before slicing.
    """

    image_token_id = hf_config.image_token_id
    video_token_id = hf_config.video_token_id
    vision_start_token_id = hf_config.vision_start_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    tokens_per_second = getattr(hf_config.vision_config,
                                "tokens_per_second", 1.0)

    # Count images/videos by inspecting the token that immediately follows
    # each vision-start token.
    input_tokens_tensor = torch.tensor(input_tokens)
    vision_start_indices = torch.argwhere(
        input_tokens_tensor == vision_start_token_id).squeeze(1)
    vision_tokens = input_tokens_tensor[vision_start_indices + 1]
    image_nums = (vision_tokens == image_token_id).sum()
    video_nums = (vision_tokens == video_token_id).sum()
    llm_pos_ids_list: list = []

    # `st` is the index in `input_tokens` where the next unprocessed span
    # begins.
    st = 0
    remain_images, remain_videos = image_nums, video_nums

    image_index, video_index = 0, 0
    for _ in range(image_nums + video_nums):
        # Stays 0.0 for images, which zeroes their temporal index below.
        video_second_per_grid_t = 0.0
        # Find the next image/video placeholder at or after `st`;
        # `len(input_tokens) + 1` is the "none left" sentinel.
        if image_token_id in input_tokens and remain_images > 0:
            ed_image = input_tokens.index(image_token_id, st)
        else:
            ed_image = len(input_tokens) + 1
        if video_token_id in input_tokens and remain_videos > 0:
            ed_video = input_tokens.index(video_token_id, st)
        else:
            ed_video = len(input_tokens) + 1
        # Process whichever placeholder appears first.
        if ed_image < ed_video:
            t, h, w = (
                image_grid_thw[image_index][0],
                image_grid_thw[image_index][1],
                image_grid_thw[image_index][2],
            )
            image_index += 1
            remain_images -= 1
            ed = ed_image
        else:
            t, h, w = (
                video_grid_thw[video_index][0],
                video_grid_thw[video_index][1],
                video_grid_thw[video_index][2],
            )
            video_second_per_grid_t = 1.0
            if second_per_grid_ts:
                video_second_per_grid_t = second_per_grid_ts[video_index]
            video_index += 1
            remain_videos -= 1
            ed = ed_video

        # Spatial dimensions are reduced by the merge size; the temporal
        # dimension is kept as is.
        llm_grid_t, llm_grid_h, llm_grid_w = \
            t, h // spatial_merge_size, w // spatial_merge_size
        text_len = ed - st

        # Text before the vision item: continue from the largest position
        # emitted so far, identical on all three rows.
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        # Temporal index per grid cell, scaled by seconds-per-grid and
        # tokens-per-second, then truncated to integers.
        t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
            -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                   tokens_per_second).long().flatten()

        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
            llm_grid_t, -1, llm_grid_w).flatten()
        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
            llm_grid_t, llm_grid_h, -1).flatten()
        # Vision positions start right after the preceding text span.
        llm_pos_ids_list.append(
            torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
        # Advance past this item; assumes the prompt holds exactly
        # llm_grid_t * llm_grid_h * llm_grid_w placeholder tokens for it
        # -- TODO confirm against the prompt-expansion code.
        st = ed + llm_grid_t * llm_grid_h * llm_grid_w

    # Trailing text after the last vision item (or the whole prompt when
    # there is none).
    if st < len(input_tokens):
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        text_len = len(input_tokens) - st
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    # The delta is computed on the full prompt, before slicing.
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]

PyTorch-native implementation of the rotary embedding forward pass.

Parameters:

Name Type Description Default
positions Tensor

[num_tokens,] (text only) or [3, num_tokens] (T/H/W positions with multimodal inputs)

required
query Tensor

[num_tokens, num_heads * head_size]

required
key Optional[Tensor]

[num_tokens, num_kv_heads * head_size]

None
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding to `query` and `key` using plain PyTorch ops.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]

    Returns:
        The rotated query and key, each reshaped back to its input shape.
    """
    assert positions.ndim in (1, 2)
    assert key is not None

    num_tokens = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
    if positions.ndim == 2:
        assert self.mrope_section

        # Section `axis` of the last dimension is taken from row `axis`
        # of the position-indexed cache, then the sections are rejoined.
        cos_sections = cos.split(self.mrope_section, dim=-1)
        sin_sections = sin.split(self.mrope_section, dim=-1)
        cos = torch.cat(
            [section[axis] for axis, section in enumerate(cos_sections)],
            dim=-1)
        sin = torch.cat(
            [section[axis] for axis, section in enumerate(sin_sections)],
            dim=-1)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        # Rotate the first `rotary_dim` features of every head and pass
        # the remaining features through unchanged.
        original_shape = states.shape
        per_head = states.view(num_tokens, -1, self.head_size)
        rotated = _apply_rotary_emb(per_head[..., :self.rotary_dim], cos,
                                    sin, self.is_neox_style)
        untouched = per_head[..., self.rotary_dim:]
        return torch.cat((rotated, untouched),
                         dim=-1).reshape(original_shape)

    query = _rotate(query)
    key = _rotate(key)
    return query, key

get_input_positions classmethod

get_input_positions(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Optional[
        Union[list[list[int]], Tensor]
    ],
    video_grid_thw: Optional[
        Union[list[list[int]], Tensor]
    ],
    second_per_grid_ts: Optional[list[float]],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]

Get mrope input positions and delta value.

Source code in vllm/model_executor/layers/rotary_embedding.py
@classmethod
def get_input_positions(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
    video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
    second_per_grid_ts: Optional[list[float]],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]:
    """Return mrope input positions as nested lists, plus the delta value.

    `None` values for the grid and timing arguments are replaced by empty
    lists, the work is delegated to `get_input_positions_tensor`, and the
    resulting tensor is converted with `tolist()`.
    """
    if image_grid_thw is None:
        image_grid_thw = []
    if video_grid_thw is None:
        video_grid_thw = []
    if second_per_grid_ts is None:
        second_per_grid_ts = []

    positions, delta = cls.get_input_positions_tensor(
        input_tokens=input_tokens,
        hf_config=hf_config,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        context_len=context_len,
        seq_len=seq_len,
        audio_feature_lengths=audio_feature_lengths,
        use_audio_in_video=use_audio_in_video,
    )
    return positions.tolist(), delta

get_input_positions_tensor classmethod

get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[Tensor, int]
Source code in vllm/model_executor/layers/rotary_embedding.py
@classmethod
def get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Dispatch to the model-specific mrope position builder.

    Selection order: the omni builder when `thinker_uses_mrope(hf_config)`
    is true, the glm4v builder when `hf_config.model_type` contains
    "glm4v", and the generic vision-language builder otherwise.
    """
    from vllm.transformers_utils.config import thinker_uses_mrope

    # Keyword arguments accepted by all three builders.
    shared_kwargs = dict(
        input_tokens=input_tokens,
        hf_config=hf_config,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        context_len=context_len,
        seq_len=seq_len,
    )

    if thinker_uses_mrope(hf_config):
        return cls._omni_get_input_positions_tensor(
            second_per_grid_ts=second_per_grid_ts,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
            **shared_kwargs,
        )
    if "glm4v" in hf_config.model_type:
        # The glm4v builder is not given `second_per_grid_ts`.
        return cls._glm4v_get_input_positions_tensor(**shared_kwargs)
    return cls._vl_get_input_positions_tensor(
        second_per_grid_ts=second_per_grid_ts,
        **shared_kwargs,
    )

get_next_input_positions staticmethod

get_next_input_positions(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> list[list[int]]
Source code in vllm/model_executor/layers/rotary_embedding.py
@staticmethod
def get_next_input_positions(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> list[list[int]]:
    return [
        list(
            range(context_len + mrope_position_delta,
                  seq_len + mrope_position_delta)) for _ in range(3)
    ]

get_next_input_positions_tensor staticmethod

get_next_input_positions_tensor(
    out: ndarray,
    out_offset: int,
    mrope_position_delta: int,
    context_len: int,
    num_new_tokens: int,
)
Source code in vllm/model_executor/layers/rotary_embedding.py
@staticmethod
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
                                    mrope_position_delta: int,
                                    context_len: int, num_new_tokens: int):

    values = np.arange(mrope_position_delta + context_len,
                       mrope_position_delta + context_len + num_new_tokens,
                       dtype=out.dtype)
    out[:, out_offset:out_offset + num_new_tokens] = values

omni_get_updates_use_audio_in_video classmethod

omni_get_updates_use_audio_in_video(
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], Tensor],
    video_second_per_grid_t: float,
) -> list[int]

Get video prompt updates when use_audio_in_video is True.

In this case, audio and vision update ids will be split into chunks and interleaved (details in _omni_get_input_positions_tensor).

<|video_bos|><|VIDEO|><|video_eos|> => <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

Source code in vllm/model_executor/layers/rotary_embedding.py
@classmethod
def omni_get_updates_use_audio_in_video(
    cls,
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], torch.Tensor],
    video_second_per_grid_t: float,
) -> list[int]:
    """Build the video prompt update ids when `use_audio_in_video` is True.

    Vision and audio placeholder ids are emitted chunk by chunk and
    interleaved (details in `_omni_get_input_positions_tensor`), wrapped
    in the audio start/end ids:

    <|video_bos|><|VIDEO|><|video_eos|> =>
    <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

    Any audio placeholders not consumed by the chunks are appended after
    the last chunk.
    """
    audio_token_id = thinker_config.audio_token_index
    video_token_id = thinker_config.video_token_index
    audio_bos_id = thinker_config.audio_start_token_id
    audio_eos_id = thinker_config.audio_end_token_id
    chunk_seconds = thinker_config.seconds_per_chunk
    vision_config = thinker_config.vision_config
    merge_area = vision_config.spatial_merge_size**2
    tokens_per_second = getattr(vision_config, "tokens_per_second", 25)

    grid_t, grid_h, grid_w = (video_grid_thw[0], video_grid_thw[1],
                              video_grid_thw[2])
    # Upper bound on audio placeholders emitted per chunk; also the
    # bucket width used to group the temporal indices.
    audio_per_chunk = int(tokens_per_second * chunk_seconds)
    temporal_index = (torch.arange(grid_t) * video_second_per_grid_t *
                      tokens_per_second).long()
    chunks = cls._split_list_into_ranges(temporal_index, audio_per_chunk)

    updates = [audio_bos_id]
    audio_emitted = 0
    for chunk in chunks:
        video_count = len(chunk) * grid_h * grid_w // merge_area
        updates.extend([video_token_id] * video_count)

        audio_count = min(audio_per_chunk, audio_len - audio_emitted)
        updates.extend(audio_count * [audio_token_id])
        audio_emitted += audio_count

    # Flush audio placeholders that did not fit into the chunks.
    if audio_emitted < audio_len:
        updates.extend((audio_len - audio_emitted) * [audio_token_id])
    updates.append(audio_eos_id)

    return updates

NTKScalingRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with fixed and mixed NTK scaling. https://kexue.fm/archives/9706

Source code in vllm/model_executor/layers/rotary_embedding.py
class NTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with fixed and mixed NTK scaling.

    Reference: https://kexue.fm/archives/9706
    """

    def __init__(self,
                 head_size: int,
                 rotary_dim: int,
                 max_position_embeddings: int,
                 base: float,
                 is_neox_style: bool,
                 scaling_factor: float,
                 dtype: torch.dtype,
                 mixed_b: Optional[float] = None) -> None:
        # Both attributes are stored before the base constructor runs;
        # presumably it calls `_compute_inv_freq`, which reads them --
        # confirm in RotaryEmbedding.__init__.
        self.mixed_b = mixed_b
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the NTK-scaled inverse frequencies.

        The `base` argument is ignored; `self.base` is always used.
        """
        if self.mixed_b is None:
            # No `mixed_b`: scale the base, then divide uniformly.
            scaled_base = self.base * self.scaling_factor
            inv_freq = super()._compute_inv_freq(scaled_base)
            return inv_freq / self.scaling_factor**(2 / self.rotary_dim)

        # With `mixed_b`: base is unchanged, each frequency gets its own
        # divisor exp(a * m**mixed_b) for m = 1 .. rotary_dim // 2.
        inv_freq = super()._compute_inv_freq(self.base)
        half_dim = self.rotary_dim / 2
        a = torch.tensor(self.scaling_factor).log() / half_dim**self.mixed_b
        dim_index = torch.arange(1, self.rotary_dim // 2 + 1).float()
        lambda_1_m = (a * dim_index**self.mixed_b).exp()
        return inv_freq / lambda_1_m

mixed_b instance-attribute

mixed_b = mixed_b

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: dtype,
    mixed_b: Optional[float] = None,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(self,
             head_size: int,
             rotary_dim: int,
             max_position_embeddings: int,
             base: float,
             is_neox_style: bool,
             scaling_factor: float,
             dtype: torch.dtype,
             mixed_b: Optional[float] = None) -> None:
    """Store the NTK parameters, then run the base constructor.

    `mixed_b=None` selects the branch of `_compute_inv_freq` that scales
    the base by `scaling_factor`.
    """
    # Stored before the base constructor runs; presumably it calls
    # `_compute_inv_freq`, which reads both -- confirm in RotaryEmbedding.
    self.mixed_b = mixed_b
    self.scaling_factor = scaling_factor
    super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                     is_neox_style, dtype)

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Return the NTK-scaled inverse frequencies.

    The `base` argument is ignored; `self.base` is always used.
    """
    if self.mixed_b is None:
        # No `mixed_b`: scale the base, then divide uniformly.
        scaled_base = self.base * self.scaling_factor
        inv_freq = super()._compute_inv_freq(scaled_base)
        return inv_freq / self.scaling_factor**(2 / self.rotary_dim)

    # With `mixed_b`: base is unchanged, each frequency gets its own
    # divisor exp(a * m**mixed_b) for m = 1 .. rotary_dim // 2.
    inv_freq = super()._compute_inv_freq(self.base)
    half_dim = self.rotary_dim / 2
    a = torch.tensor(self.scaling_factor).log() / half_dim**self.mixed_b
    dim_index = torch.arange(1, self.rotary_dim // 2 + 1).float()
    lambda_1_m = (a * dim_index**self.mixed_b).exp()
    return inv_freq / lambda_1_m

Phi3LongRoPEScaledRotaryEmbedding

Bases: Module

Phi3 family of models scaled rotary embedding.

Based on the original RotaryEmbedding implementation.

Source code in vllm/model_executor/layers/rotary_embedding.py
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.

    Two cos/sin caches are built, one from `short_factor` covering
    `original_max_position_embeddings` positions and one from
    `long_factor` covering `max_position_embeddings` positions. They are
    stored concatenated (short first) in `long_short_cos_sin_cache`.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: list[float],
        long_factor: list[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        """Build and register the concatenated short/long cos-sin cache.

        Raises:
            ValueError: If `is_neox_style` is `False`.
        """
        super().__init__()

        if is_neox_style is False:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.rotary_dim = rotary_dim
        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        # Default magnitude scale, used for whichever of
        # short_mscale / long_mscale the caller left as None.
        scale = self.max_position_embeddings / \
                self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) /
                math.log(self.original_max_position_embeddings))
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(dtype)

        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(dtype)

        # Rows [0, original_max) hold the short cache; the long cache
        # follows, so long position p lives at row original_max + p.
        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
        """Return inverse frequencies divided per-dimension by the factors."""
        factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (factors * (self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: list[float],
        mscale: float,
    ) -> torch.Tensor:
        """Return a `[max_position_embeddings, rotary_dim]` cache.

        The first half of the last dimension is `cos * mscale`, the second
        half is `sin * mscale`.
        """
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate `query` and `key` with the short or long cache.

        The long cache is used for every token as soon as any position
        reaches `original_max_position_embeddings`; otherwise the short
        cache is used. `offsets`, when given, is added to the cache index.

        Returns:
            The rotated query and key with the head dimension flattened
            back into the last axis.
        """
        assert key is not None
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        # The short cache only has rows 0..k-1, so position k already
        # needs the long cache (`>=`, not `>`); otherwise index k would
        # read long-cache row 0. The offset is a tensor of 0 or k so the
        # selection needs no Python branch on tensor data.
        long_prompt_offset = (torch.any(positions >= k).float() *
                              torch.full_like(positions, k)).long()
        idx = torch.add(positions, long_prompt_offset)
        if offsets is not None:
            idx = torch.add(idx, offsets)
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # Duplicate to span the full rotary_dim and add a heads axis for
        # broadcasting.
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
        query = torch.cat((query_rot, query_pass), dim=-1)

        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
        key = torch.cat((key_rot, key_pass), dim=-1)

        return query.flatten(-2), key.flatten(-2)

base instance-attribute

base = base

head_size instance-attribute

head_size = head_size

long_factor instance-attribute

long_factor = long_factor

long_mscale instance-attribute

long_mscale = long_mscale

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

original_max_position_embeddings instance-attribute

original_max_position_embeddings = (
    original_max_position_embeddings
)

rotary_dim instance-attribute

rotary_dim = rotary_dim

short_factor instance-attribute

short_factor = short_factor

short_mscale instance-attribute

short_mscale = short_mscale

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    original_max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
    short_factor: list[float],
    long_factor: list[float],
    short_mscale: Optional[float] = None,
    long_mscale: Optional[float] = None,
)
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    original_max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    short_factor: list[float],
    long_factor: list[float],
    short_mscale: Optional[float] = None,
    long_mscale: Optional[float] = None,
):
    """Build and register the concatenated short/long cos-sin cache.

    Raises:
        ValueError: If `is_neox_style` is `False`.
    """
    super().__init__()

    if is_neox_style is False:
        raise ValueError(
            "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
        )

    self.rotary_dim = rotary_dim
    self.head_size = head_size
    self.max_position_embeddings = max_position_embeddings
    self.original_max_position_embeddings = original_max_position_embeddings
    self.base = base
    self.short_factor = short_factor
    self.long_factor = long_factor

    # Default magnitude scale for whichever mscale was left as None.
    length_ratio = (self.max_position_embeddings /
                    self.original_max_position_embeddings)
    if length_ratio <= 1.0:
        default_mscale = 1.0
    else:
        default_mscale = math.sqrt(
            1 + math.log(length_ratio) /
            math.log(self.original_max_position_embeddings))

    self.short_mscale = (default_mscale
                         if short_mscale is None else short_mscale)
    self.long_mscale = default_mscale if long_mscale is None else long_mscale

    short_cache = self._compute_cos_sin_cache(
        original_max_position_embeddings, short_factor,
        self.short_mscale).to(dtype)
    long_cache = self._compute_cos_sin_cache(
        max_position_embeddings, long_factor, self.long_mscale).to(dtype)

    # Short cache rows come first, followed by the long cache rows.
    self.register_buffer("long_short_cos_sin_cache",
                         torch.cat([short_cache, long_cache], dim=0),
                         persistent=False)

_compute_cos_sin_cache

_compute_cos_sin_cache(
    max_position_embeddings: int,
    rescale_factors: list[float],
    mscale: float,
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(
    self,
    max_position_embeddings: int,
    rescale_factors: list[float],
    mscale: float,
) -> torch.Tensor:
    """Return the cos/sin table for positions `0 .. max_position_embeddings-1`.

    Each row is the `mscale`-scaled cosines followed by the
    `mscale`-scaled sines of `position * inv_freq`.
    """
    inv_freq = self._compute_inv_freq(rescale_factors)
    position_ids = torch.arange(max_position_embeddings, dtype=torch.float)
    angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
    scaled_cos = angles.cos() * mscale
    scaled_sin = angles.sin() * mscale
    return torch.cat((scaled_cos, scaled_sin), dim=-1)

_compute_inv_freq

_compute_inv_freq(rescale_factors: list[float]) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
    """Return `1 / (factor * base ** (2i / rotary_dim))` for each even index."""
    factors = torch.tensor(rescale_factors, dtype=torch.float32)
    exponents = torch.arange(0, self.rotary_dim, 2,
                             dtype=torch.float) / self.rotary_dim
    return 1.0 / (factors * (self.base**exponents))

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Rotate ``query`` and ``key`` using the combined long/short cache.

    Args:
        positions: Index tensor of token positions. It is passed to
            ``torch.index_select``, so it is presumably 1-D — TODO confirm
            against callers.
        query: Tensor whose last dimension is a multiple of ``head_size``.
        key: Same layout as ``query``. Must not be ``None`` here.
        offsets: Optional tensor added to the cache indices.

    Returns:
        ``(query, key)`` with the first ``rotary_dim`` features of every
        head rotated and the per-head dimension flattened back.
    """
    assert key is not None
    # Split the trailing dimension into (num_heads, head_size).
    query = query.view(*query.shape[:-1], -1, self.head_size)
    key = key.view(*key.shape[:-1], -1, self.head_size)

    k = self.original_max_position_embeddings
    # If ANY position exceeds k, every index is shifted by k; otherwise the
    # shift is zero. NOTE(review): the shift presumably skips the
    # short-factor rows at the start of long_short_cos_sin_cache so that the
    # whole batch reads the long-factor rows — confirm against __init__.
    long_prompt_offset = (torch.any(positions > k).float() *
                          torch.full_like(positions, k)).long()
    # long_prompt_offset is always a tensor, so it is added unconditionally.
    idx = torch.add(positions, long_prompt_offset)
    if offsets is not None:
        idx = torch.add(idx, offsets)
    cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

    # Each cache row is [cos | sin]; duplicate each half so it spans the
    # rotated features, and add a head axis for broadcasting.
    cos, sin = cos_sin.chunk(2, dim=-1)
    cos = cos.repeat(1, 2).unsqueeze(-2)
    sin = sin.repeat(1, 2).unsqueeze(-2)

    # Only the leading rotary_dim features are rotated; the rest pass through.
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
    query = torch.cat((query_rot, query_pass), dim=-1)

    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
    key = torch.cat((key_rot, key_pass), dim=-1)

    return query.flatten(-2), key.flatten(-2)

RotaryEmbedding

Bases: CustomOp

Original rotary positional embedding.

Source code in vllm/model_executor/layers/rotary_embedding.py
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding.

    A ``cos_sin_cache`` buffer with one row per position is built at
    construction time. Every ``forward_*`` method looks up the rows for the
    requested positions and applies the rotation to ``query`` (and to ``key``
    when it is given), returning ``(query, key)``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Store the configuration and precompute the cos/sin cache.

        Args:
            head_size: Size of one head; the PyTorch code paths view inputs
                as ``[num_tokens, -1, head_size]`` before rotating.
            rotary_dim: Number of leading features of each head that are
                rotated.
            max_position_embeddings: Number of positions (rows) in the cache
                built by this class.
            base: Base of the inverse-frequency progression.
            is_neox_style: Whether to use the Neox-style or GPT-J-style
                rotary layout.
            dtype: Dtype the cache is cast to before it is registered.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        # Registered as non-persistent, i.e. excluded from the state dict.
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency.

        Entry ``i`` is ``1 / base ** (2 * i / rotary_dim)``.
        """
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Row ``p`` holds ``cos(p * inv_freq)`` followed by
        ``sin(p * inv_freq)``.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        # Outer product of positions and inverse frequencies.
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """A PyTorch-native implementation of forward().

        ``offsets``, when given, is added to ``positions`` before the cache
        lookup. Only the first ``rotary_dim`` features of each head are
        rotated; the remainder is concatenated back unchanged. The returned
        tensors have the same shapes as the inputs.
        """
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
                                            self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)
            key_rot = key[..., :self.rotary_dim]
            key_pass = key[..., self.rotary_dim:]
            key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
                                              self.is_neox_style)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotation through the ``vllm._custom_ops`` kernels."""
        from vllm import _custom_ops as ops

        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
        # is expensive, so avoid calling it if possible
        if self.cos_sin_cache.device != query.device or \
            self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)

        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotation through the IPEX kernels.

        Falls back to ``forward_native`` when ``key`` is ``None``.
        """
        from vllm._ipex_ops import ipex_ops as ops

        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
        # is expensive, so only re-assign when a move/cast is really needed
        # (same guard as forward_cuda).
        if self.cos_sin_cache.device != positions.device or \
            self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                       dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if key is None:
            # XPU kernel doesn't support key=None so fall back to native impl
            # TODO(sarckk): add support for optional key in
            # ipex.llm.functional.rotary_embedding_batched
            return self.forward_native(positions, query, key, offsets)
        else:
            if offsets is not None:
                ops.batched_rotary_embedding(positions, query, key,
                                             self.head_size,
                                             self.cos_sin_cache,
                                             self.is_neox_style,
                                             self.rotary_dim, offsets)
            else:
                ops.rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_hpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotation through the Habana ``apply_rotary_pos_emb``.

        Only the first ``rotary_dim`` features of each head are rotated; the
        returned tensors have the same shapes as the inputs.
        """
        from habana_frameworks.torch.hpex.kernels import (
            RotaryPosEmbeddingMode, apply_rotary_pos_emb)
        if offsets is not None:
            offsets = offsets.view(positions.shape[0], -1)
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        # Add a singleton head axis so cos/sin broadcast over heads.
        cos_sin = self.cos_sin_cache.index_select(0, positions).view(
            num_tokens, 1, -1)
        cos, sin = cos_sin.chunk(2, dim=-1)
        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # to query hidden dimension, so the original tensors need to be
        # expanded
        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # and expansion of cos/sin tensors via concatenation
        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
        # and expansion of cos/sin tensors via repeat_interleave
        rope_mode: RotaryPosEmbeddingMode
        if self.is_neox_style:
            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
            cos = torch.cat((cos, cos), dim=-1)
            sin = torch.cat((sin, sin), dim=-1)
        else:
            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
            sin = torch.repeat_interleave(sin,
                                          2,
                                          dim=-1,
                                          output_size=cos_sin.shape[-1])
            cos = torch.repeat_interleave(cos,
                                          2,
                                          dim=-1,
                                          output_size=cos_sin.shape[-1])

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0,
                                         rope_mode)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)
            key_rot = key[..., :self.rotary_dim]
            key_pass = key[..., self.rotary_dim:]
            key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0,
                                           rope_mode)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_neuron(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotation on Neuron using view-based slicing.

        When ``rotary_dim == head_size`` the shared ``_apply_rotary_emb``
        helper is used; otherwise the rotated and pass-through parts are
        split out through 2-D views and handled by a local helper.
        """

        def _apply_rotary_emb_neuron(
            x: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
            is_neox_style: bool,
        ) -> torch.Tensor:
            """Rotate ``x`` by ``cos``/``sin`` in Neox or GPT-J layout."""
            cos = cos.unsqueeze(-2).to(x.dtype)
            sin = sin.unsqueeze(-2).to(x.dtype)
            if is_neox_style:
                x1, x2 = torch.chunk(x, 2, dim=-1)
            else:
                # Equivalent to x1 = x[..., ::2] and x2 = x[..., 1::2], but
                # done on a 2-D view of x.
                d = x.shape[-1] // 2
                x_reshaped = x.view(-1, x.shape[-1])
                x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
                x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
            o1 = x1 * cos - x2 * sin
            o2 = x2 * cos + x1 * sin
            if is_neox_style:
                return torch.cat((o1, o2), dim=-1)
            else:
                return torch.stack((o1, o2), dim=-1).flatten(-2)

        if offsets is not None:
            positions = positions + offsets

        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
        # is expensive, so only re-assign when a move/cast is really needed
        # (same guard as forward_cuda).
        if self.cos_sin_cache.device != query.device or \
            self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)

        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)

        if self.rotary_dim == self.head_size:
            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
            query = query.reshape(query_shape)
            if key is not None:
                key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
                key = key.reshape(key_shape)
        else:
            head_size = query.shape[-1]
            query_reshaped = query.view(-1, head_size)
            query_pass = query_reshaped[:, self.rotary_dim:].view(
                *query.shape[:-1], head_size - self.rotary_dim)
            query_rot = query_reshaped[:, :self.rotary_dim].view(
                *query.shape[:-1], self.rotary_dim)
            query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
                                                 self.is_neox_style)
            query = torch.cat((query_rot, query_pass),
                              dim=-1).reshape(query_shape)

            if key is not None:
                key_reshaped = key.view(-1, head_size)
                key_pass = key_reshaped[:, self.rotary_dim:].view(
                    *key.shape[:-1], head_size - self.rotary_dim)
                key_rot = key_reshaped[:, :self.rotary_dim].view(
                    *key.shape[:-1], self.rotary_dim)
                key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
                                                   self.is_neox_style)
                key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's repr."""
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s

base instance-attribute

base = base

cos_sin_cache instance-attribute

cos_sin_cache: Tensor

dtype instance-attribute

dtype = dtype

head_size instance-attribute

head_size = head_size

is_neox_style instance-attribute

is_neox_style = is_neox_style

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

rotary_dim instance-attribute

rotary_dim = rotary_dim

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
    """Record the configuration and register the precomputed cos/sin cache.

    The cache is produced by ``self._compute_cos_sin_cache()``, cast to
    ``dtype`` and registered as the non-persistent buffer ``cos_sin_cache``.
    """
    super().__init__()
    # Everything the cache computation reads must be set before it runs.
    self.dtype = dtype
    self.is_neox_style = is_neox_style
    self.base = base
    self.max_position_embeddings = max_position_embeddings
    self.rotary_dim = rotary_dim
    self.head_size = head_size

    self.cos_sin_cache: torch.Tensor
    self.register_buffer("cos_sin_cache",
                         self._compute_cos_sin_cache().to(dtype),
                         persistent=False)

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor

Compute the cos and sin cache.

Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the table of cosines and sines for every cached position.

    Row ``p`` holds ``cos(p * inv_freq)`` followed by ``sin(p * inv_freq)``
    for ``p`` in ``range(self.max_position_embeddings)``.
    """
    position_ids = torch.arange(self.max_position_embeddings,
                                dtype=torch.float)
    # Outer product: one row of rotation angles per position.
    angles = torch.einsum("i,j -> ij", position_ids,
                          self._compute_inv_freq(self.base))
    return torch.cat((angles.cos(), angles.sin()), dim=-1)

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor

Compute the inverse frequency.

Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Return the inverse frequencies ``1 / base ** (2 * i / rotary_dim)``."""
    # NOTE(woosuk): To exactly match the HF implementation, we need to
    # use CPU to compute the cache and then move it to GPU. However, we
    # create the cache on GPU for faster initialization. This may cause
    # a slight numerical difference between the HF implementation and ours.
    exponents = torch.arange(0, self.rotary_dim, 2,
                             dtype=torch.float) / self.rotary_dim
    return 1.0 / (base**exponents)

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/rotary_embedding.py
def extra_repr(self) -> str:
    """Return a comma-separated summary of the layer's configuration."""
    return (f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
            f", max_position_embeddings={self.max_position_embeddings}"
            f", base={self.base}, is_neox_style={self.is_neox_style}")

forward_cuda

forward_cuda(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply the rotation through the ``vllm._custom_ops`` kernels.

    The cache is moved/cast to match ``query`` first if needed. The kernels
    update ``query`` and ``key`` in place; the same tensors are returned.
    """
    from vllm import _custom_ops as ops

    cache = self.cos_sin_cache
    cache_matches_query = (cache.device == query.device
                           and cache.dtype == query.dtype)
    if not cache_matches_query:
        # Assigning to the attribute goes through nn.Module.__setattr__,
        # which is expensive, so it is done only when a move/cast is needed.
        self.cos_sin_cache = cache.to(query.device, dtype=query.dtype)

    # Both ops are in-place operations that update the query and key tensors.
    if offsets is None:
        ops.rotary_embedding(positions, query, key, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
    else:
        ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache,
                                     self.is_neox_style, self.rotary_dim,
                                     offsets)
    return query, key

forward_hpu

forward_hpu(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward_hpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply the rotation through the Habana ``apply_rotary_pos_emb`` kernel.

    Only the first ``rotary_dim`` features of each head are rotated; the
    returned tensors have the same shapes as the inputs.
    """
    from habana_frameworks.torch.hpex.kernels import (
        RotaryPosEmbeddingMode, apply_rotary_pos_emb)
    if offsets is not None:
        positions = positions + offsets.view(positions.shape[0], -1)
    positions = positions.flatten()
    num_tokens = positions.shape[0]
    # Singleton head axis so cos/sin broadcast over heads.
    cos_sin = self.cos_sin_cache.index_select(0, positions).view(
        num_tokens, 1, -1)
    cos, sin = cos_sin.chunk(2, dim=-1)

    # The HPU RoPE kernel needs cos/sin to have the same hidden size as the
    # rotated slice, so each half is expanded back to full width:
    #   * GPT-J: position_ids = None, offset = 0, mode = PAIRWISE, expansion
    #     via repeat_interleave;
    #   * GPT-NeoX: position_ids = None, offset, mode = BLOCKWISE, expansion
    #     via concatenation.
    rope_mode: RotaryPosEmbeddingMode
    if not self.is_neox_style:
        rope_mode = RotaryPosEmbeddingMode.PAIRWISE
        full_width = cos_sin.shape[-1]
        sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=full_width)
        cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=full_width)
    else:
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Rotate the leading rotary_dim features of every head and glue the
        # untouched remainder back on, restoring the caller's shape.
        original_shape = x.shape
        heads = x.view(num_tokens, -1, self.head_size)
        rotated = heads[..., :self.rotary_dim]
        untouched = heads[..., self.rotary_dim:]
        rotated = apply_rotary_pos_emb(rotated, cos, sin, None, 0, rope_mode)
        return torch.cat((rotated, untouched), dim=-1).reshape(original_shape)

    query = _rotate(query)
    if key is not None:
        key = _rotate(key)
    return query, key

forward_native

forward_native(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]

A PyTorch-native implementation of forward().

Source code in vllm/model_executor/layers/rotary_embedding.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """A PyTorch-native implementation of forward().

    ``offsets``, when given, is added to ``positions`` before the cache
    lookup. Only the first ``rotary_dim`` features of each head are rotated;
    the returned tensors have the same shapes as the inputs.
    """
    if offsets is not None:
        positions = positions + offsets
    flat_positions = positions.flatten()
    num_tokens = flat_positions.shape[0]
    cos, sin = self.cos_sin_cache.index_select(0, flat_positions).chunk(
        2, dim=-1)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Rotate the leading rotary_dim features of every head and glue the
        # untouched remainder back on, restoring the caller's shape.
        original_shape = x.shape
        heads = x.view(num_tokens, -1, self.head_size)
        rotated = heads[..., :self.rotary_dim]
        untouched = heads[..., self.rotary_dim:]
        rotated = _apply_rotary_emb_torch(rotated, cos, sin,
                                          self.is_neox_style)
        return torch.cat((rotated, untouched), dim=-1).reshape(original_shape)

    query = _rotate(query)
    # key may be None in some cases, e.g. cross-layer KV sharing
    if key is not None:
        key = _rotate(key)
    return query, key

forward_neuron

forward_neuron(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward_neuron(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply the rotation on Neuron using view-based slicing.

    When ``rotary_dim == head_size`` the shared ``_apply_rotary_emb`` helper
    is used; otherwise the rotated and pass-through parts are split out
    through 2-D views and handled by a local helper. The returned tensors
    have the same shapes as the inputs.
    """

    def _apply_rotary_emb_neuron(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool,
    ) -> torch.Tensor:
        """Rotate ``x`` by ``cos``/``sin`` in Neox or GPT-J layout."""
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)
        if is_neox_style:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            # Equivalent to x1 = x[..., ::2] and x2 = x[..., 1::2], but done
            # on a 2-D view of x.
            d = x.shape[-1] // 2
            x_reshaped = x.view(-1, x.shape[-1])
            x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
            x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        if is_neox_style:
            return torch.cat((o1, o2), dim=-1)
        else:
            return torch.stack((o1, o2), dim=-1).flatten(-2)

    if offsets is not None:
        positions = positions + offsets

    # `self.cos_sin_cache = ...` goes through nn.Module.__setattr__, which is
    # expensive, so only re-assign when a move/cast is really needed.
    if self.cos_sin_cache.device != query.device or \
        self.cos_sin_cache.dtype != query.dtype:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)

    positions = positions.flatten()
    num_tokens = positions.shape[0]
    cos_sin = self.cos_sin_cache.index_select(0, positions)
    cos, sin = cos_sin.chunk(2, dim=-1)

    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    if key is not None:
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)

    if self.rotary_dim == self.head_size:
        # Full-head rotation: no pass-through part to split off.
        query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
        query = query.reshape(query_shape)
        if key is not None:
            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
            key = key.reshape(key_shape)
    else:
        # Partial rotation: slice on a 2-D view, then restore the head axis.
        head_size = query.shape[-1]
        query_reshaped = query.view(-1, head_size)
        query_pass = query_reshaped[:, self.rotary_dim:].view(
            *query.shape[:-1], head_size - self.rotary_dim)
        query_rot = query_reshaped[:, :self.rotary_dim].view(
            *query.shape[:-1], self.rotary_dim)
        query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
                                             self.is_neox_style)
        query = torch.cat((query_rot, query_pass),
                          dim=-1).reshape(query_shape)

        if key is not None:
            key_reshaped = key.view(-1, head_size)
            key_pass = key_reshaped[:, self.rotary_dim:].view(
                *key.shape[:-1], head_size - self.rotary_dim)
            key_rot = key_reshaped[:, :self.rotary_dim].view(
                *key.shape[:-1], self.rotary_dim)
            key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
                                               self.is_neox_style)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key

forward_xpu

forward_xpu(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding.py
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply the rotation through the IPEX kernels.

    The cache is moved to ``positions.device`` and cast to ``query.dtype``
    first if needed. Falls back to ``forward_native`` when ``key`` is
    ``None``; otherwise the kernels update ``query`` and ``key`` in place and
    the same tensors are returned.
    """
    from vllm._ipex_ops import ipex_ops as ops

    # `self.cos_sin_cache = ...` goes through nn.Module.__setattr__, which is
    # expensive, so only re-assign when a move/cast is really needed.
    if self.cos_sin_cache.device != positions.device or \
        self.cos_sin_cache.dtype != query.dtype:
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)

    if key is None:
        # XPU kernel doesn't support key=None so fall back to native impl
        # TODO(sarckk): add support for optional key in
        # ipex.llm.functional.rotary_embedding_batched
        return self.forward_native(positions, query, key, offsets)

    # ops.rotary_embedding()/batched_rotary_embedding()
    # are in-place operations that update the query and key tensors.
    if offsets is not None:
        ops.batched_rotary_embedding(positions, query, key,
                                     self.head_size,
                                     self.cos_sin_cache,
                                     self.is_neox_style,
                                     self.rotary_dim, offsets)
    else:
        ops.rotary_embedding(positions, query, key, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
    return query, key

YaRNScalingRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with YaRN method.

Credits to Peng et al. github.com/jquesnelle/yarn

Source code in vllm/model_executor/layers/rotary_embedding.py
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        """Store the YaRN settings, then let the base class build the cache.

        The YaRN attributes are assigned before ``super().__init__`` because
        the base constructor calls ``_compute_cos_sin_cache()``, which reads
        them.
        """
        self.beta_slow = beta_slow
        self.beta_fast = beta_fast
        self.attn_factor = attn_factor
        self.extrapolation_factor = extrapolation_factor
        self.scaling_factor = scaling_factor
        # Get n-d magnitude scaling corrected for interpolation
        magnitude = _yarn_get_mscale(scaling_factor) * attn_factor
        self.mscale = float(magnitude)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies.

        A linear ramp mask between the correction bounds decides, per
        frequency, how much of the unscaled (extrapolation) value versus the
        ``scaling_factor``-divided (interpolation) value is used.
        """
        exponents = torch.arange(0, self.rotary_dim, 2,
                                 dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = _yarn_linear_ramp_mask(low,
                                      high,
                                      self.rotary_dim // 2,
                                      dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        return interpolated * (1 - blend) + extrapolated * blend

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin table, scaled by ``self.mscale``.

        The table covers ``max_position_embeddings * scaling_factor``
        positions.
        """
        num_positions = self.max_position_embeddings * self.scaling_factor
        position_ids = torch.arange(num_positions, dtype=torch.float32)
        # Outer product: one row of rotation angles per position.
        angles = torch.einsum("i,j -> ij", position_ids,
                              self._compute_inv_freq(self.scaling_factor))
        scaled_cos = angles.cos() * self.mscale
        scaled_sin = angles.sin() * self.mscale
        return torch.cat((scaled_cos, scaled_sin), dim=-1)

attn_factor instance-attribute

attn_factor = attn_factor

beta_fast instance-attribute

beta_fast = beta_fast

beta_slow instance-attribute

beta_slow = beta_slow

extrapolation_factor instance-attribute

extrapolation_factor = extrapolation_factor

mscale instance-attribute

mscale = float(
    _yarn_get_mscale(scaling_factor) * attn_factor
)

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: dtype,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: torch.dtype,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
) -> None:
    """Store the YaRN settings, then run the parent constructor.

    ``mscale`` is ``_yarn_get_mscale(scaling_factor) * attn_factor`` as a
    Python float. All attributes are assigned before ``super().__init__``
    is called.
    """
    self.beta_slow = beta_slow
    self.beta_fast = beta_fast
    self.attn_factor = attn_factor
    self.extrapolation_factor = extrapolation_factor
    self.scaling_factor = scaling_factor
    # Get n-d magnitude scaling corrected for interpolation
    magnitude = _yarn_get_mscale(scaling_factor) * attn_factor
    self.mscale = float(magnitude)
    super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                     is_neox_style, dtype)

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the cos/sin table, scaled by ``self.mscale``.

    The table covers ``max_position_embeddings * scaling_factor`` positions;
    row ``p`` holds the scaled cosines followed by the scaled sines of
    ``p * inv_freq``.
    """
    num_positions = self.max_position_embeddings * self.scaling_factor
    position_ids = torch.arange(num_positions, dtype=torch.float32)
    # Outer product: one row of rotation angles per position.
    angles = torch.einsum("i,j -> ij", position_ids,
                          self._compute_inv_freq(self.scaling_factor))
    scaled_cos = angles.cos() * self.mscale
    scaled_sin = angles.sin() * self.mscale
    return torch.cat((scaled_cos, scaled_sin), dim=-1)

_compute_inv_freq

_compute_inv_freq(scaling_factor: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Blend interpolated and extrapolated inverse frequencies.

    Two candidate inverse-frequency vectors are computed from
    ``self.base`` and ``self.rotary_dim``: one unscaled ("extrapolation")
    and one divided by ``scaling_factor`` ("interpolation"). They are
    mixed per element with a weight derived from a linear ramp between
    the correction-range bounds.

    Args:
        scaling_factor: Divisor applied to the interpolated frequencies.

    Returns:
        The blended inverse-frequency tensor, one entry per even index
        below ``self.rotary_dim``.
    """
    exponents = torch.arange(0, self.rotary_dim, 2,
                             dtype=torch.float) / self.rotary_dim
    powers = self.base**exponents
    extrapolated = 1.0 / powers
    interpolated = 1.0 / (scaling_factor * powers)

    ramp_start, ramp_end = _yarn_find_correction_range(
        self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
        self.max_position_embeddings)
    ramp = _yarn_linear_ramp_mask(ramp_start,
                                  ramp_end,
                                  self.rotary_dim // 2,
                                  dtype=torch.float)
    # Get n-d rotational scaling corrected for extrapolation
    extrapolation_weight = (1 - ramp) * self.extrapolation_factor
    return (interpolated * (1 - extrapolation_weight) +
            extrapolated * extrapolation_weight)

_apply_rotary_emb

_apply_rotary_emb(
    x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool
) -> Tensor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | [num_tokens, num_heads, head_size] | required |
| cos | Tensor | [num_tokens, head_size // 2] | required |
| sin | Tensor | [num_tokens, head_size // 2] | required |
| is_neox_style | bool | Whether to use the Neox-style or GPT-J-style rotary positional embeddings. | required |
Source code in vllm/model_executor/layers/rotary_embedding.py
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """Apply rotary embeddings, choosing the implementation by platform.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.

    Returns:
        The rotated tensor produced by the selected implementation.
    """
    # Anything that is not CUDA uses the pure-PyTorch implementation.
    if not current_platform.is_cuda():
        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)

    # The CUDA path is called with a leading dimension added and with the
    # style flag negated; the extra dimension is removed from the result.
    with_leading_dim = x.unsqueeze(0)
    not_neox = not is_neox_style
    rotated = apply_rotary_emb(with_leading_dim, cos, sin, not_neox)
    return rotated.squeeze(0)

_apply_rotary_emb_torch

_apply_rotary_emb_torch(
    x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Rotate ``x`` by the given cos/sin values using plain PyTorch ops.

    Args:
        x: Tensor whose last dimension is split into two halves to rotate.
        cos: Cosine values; a dimension is inserted before the last one
            and the values are cast to ``x.dtype`` before use.
        sin: Sine values, treated like ``cos``.
        is_neox_style: If true, the two halves are the first and second
            contiguous chunks of the last dimension and the result is
            their concatenation. Otherwise they are the even- and
            odd-indexed elements and the result re-interleaves them.

    Returns:
        The rotated tensor.
    """
    cos_b = cos.unsqueeze(-2).to(x.dtype)
    sin_b = sin.unsqueeze(-2).to(x.dtype)

    if is_neox_style:
        first, second = torch.chunk(x, 2, dim=-1)
        return torch.cat(
            (first * cos_b - second * sin_b, second * cos_b + first * sin_b),
            dim=-1)

    first, second = x[..., ::2], x[..., 1::2]
    paired = torch.stack(
        (first * cos_b - second * sin_b, second * cos_b + first * sin_b),
        dim=-1)
    # Collapse the trailing (pair, 2) dimensions back into one interleaved
    # dimension.
    return paired.flatten(-2)

_rotate_gptj

_rotate_gptj(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent pair ``(a, b)`` of the last dimension to ``(-b, a)``.

    Args:
        x: Input tensor; the even- and odd-indexed elements of its last
            dimension are paired up.

    Returns:
        A tensor with the paired elements re-interleaved along the last
        dimension.
    """
    evens, odds = x[..., ::2], x[..., 1::2]
    return torch.stack((-odds, evens), dim=-1).flatten(-2)

_rotate_neox

_rotate_neox(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    Args:
        x: Input tensor; its last dimension is split at
            ``x.shape[-1] // 2``.

    Returns:
        The negated second half followed by the first half.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)

_yarn_find_correction_dim

_yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float
Source code in vllm/model_executor/layers/rotary_embedding.py
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    """Evaluate ``dim * ln(max_pos / (num_rotations * 2 * pi)) / (2 * ln(base))``.

    Args:
        num_rotations: Rotation count used in the logarithm's argument.
        dim: Multiplier applied to the logarithm.
        base: Base whose logarithm forms the denominator.
        max_position_embeddings: Numerator of the logarithm's argument.

    Returns:
        The (generally fractional) value of the expression above.
    """
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(ratio)
    denominator = 2 * math.log(base)
    return numerator / denominator

_yarn_find_correction_range

_yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> tuple[int, int]
Source code in vllm/model_executor/layers/rotary_embedding.py
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> tuple[int, int]:
    """Return the integer correction bounds for two rotation counts.

    The correction dimension is computed for ``low_rot`` and rounded
    down, and for ``high_rot`` and rounded up.

    Args:
        low_rot: Rotation count used for the first bound.
        high_rot: Rotation count used for the second bound.
        dim: Dimension passed to the correction formula; ``dim - 1`` caps
            the second bound.
        base: Base passed to the correction formula.
        max_position_embeddings: Position count passed to the correction
            formula.

    Returns:
        ``(low, high)`` with ``low`` at least 0 and ``high`` at most
        ``dim - 1``.
    """
    shared_args = (dim, base, max_position_embeddings)
    # Clamp values just in case
    lower_bound = max(
        math.floor(_yarn_find_correction_dim(low_rot, *shared_args)), 0)
    upper_bound = min(
        math.ceil(_yarn_find_correction_dim(high_rot, *shared_args)), dim - 1)
    return lower_bound, upper_bound

_yarn_get_mscale

_yarn_get_mscale(scale: float = 1) -> float
Source code in vllm/model_executor/layers/rotary_embedding.py
def _yarn_get_mscale(scale: float = 1) -> float:
    """Return ``0.1 * ln(scale) + 1.0``, or ``1.0`` when ``scale <= 1``.

    Args:
        scale: Scaling factor the magnitude correction is derived from.

    Returns:
        The magnitude scale as a float.
    """
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

_yarn_linear_ramp_mask

_yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: dtype
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding.py
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype) -> torch.Tensor:
    """Build a ramp over ``dim`` indices rising linearly from 0 to 1.

    Args:
        low: Index at (and below) which the ramp is 0.
        high: Index at (and above) which the ramp is 1.
        dim: Number of entries in the returned tensor.
        dtype: Dtype of the index tensor the ramp is computed from.

    Returns:
        A 1-D tensor of length ``dim`` with values clamped to ``[0, 1]``.
    """
    # Prevent singularity: nudge the upper bound when the range is empty so
    # the division below never has a zero denominator.
    upper = high + 0.001 if low == high else high

    indices = torch.arange(dim, dtype=dtype)
    return torch.clamp((indices - low) / (upper - low), 0, 1)

get_rope

get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[
        dict[str, Any]
    ] = None,
) -> RotaryEmbedding
Source code in vllm/model_executor/layers/rotary_embedding.py
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Return a rotary embedding for the given settings, reusing cached ones.

    A cache key is built from the arguments. If ``_ROPE_DICT`` already
    holds an embedding for that key it is returned as is; otherwise a new
    one is constructed, stored in ``_ROPE_DICT`` and returned.

    The class is chosen in this order: ``DualChunkRotaryEmbedding`` when
    ``dual_chunk_attention_config`` is given, plain ``RotaryEmbedding``
    when ``rope_scaling`` is ``None`` or empty, otherwise the class
    matching ``rope_scaling["rope_type"]``.

    Args:
        head_size: Passed to the embedding constructor and part of the
            cache key.
        rotary_dim: Passed to the embedding constructor. When
            ``partial_rotary_factor < 1.0`` it is first multiplied by
            that factor and truncated with ``int()``.
        max_position: Passed to the embedding constructor, except for the
            "yarn" and "deepseek_yarn" types, which receive
            ``rope_scaling["original_max_position_embeddings"]`` instead.
        base: Passed to the embedding constructor and part of the cache
            key.
        is_neox_style: Passed to the embedding constructor and part of
            the cache key.
        rope_scaling: Optional scaling configuration. A non-empty dict
            must contain ``"rope_type"`` plus the keys read by the
            matching branch.
        dtype: Embedding dtype; defaults to ``torch.get_default_dtype()``.
        partial_rotary_factor: Fraction applied to ``rotary_dim`` when
            below 1.0.
        dual_chunk_attention_config: Optional configuration; only its
            ``"chunk_size"`` and ``"local_size"`` entries are forwarded
            to the constructor.

    Returns:
        The cached or newly created embedding.

    Raises:
        ValueError: If ``rope_type`` is not one of the handled types, or
            if a "dynamic" configuration has neither ``"alpha"`` nor
            ``"factor"``.
        KeyError: If a non-empty ``rope_scaling`` lacks ``"rope_type"`` or
            a key that the selected branch reads by subscript.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        # Same list -> tuple conversion as above. The
        # "sparse_attention_config" entry is left out of the cache key;
        # presumably its value is not hashable -- TODO confirm.
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    # The reduced rotary_dim is what goes into both the cache key and the
    # constructor calls below.
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    # A dual-chunk configuration takes precedence over rope_scaling.
    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    # Covers both rope_scaling=None and an empty dict.
    elif not rope_scaling:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            # "default" selects the multimodal variant only when an
            # "mrope_section" entry is present.
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            # "alpha" wins when both "alpha" and "factor" are present.
            if "alpha" in rope_scaling:
                scaling_alpha = rope_scaling["alpha"]
                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_alpha, dtype)
            elif "factor" in rope_scaling:
                scaling_factor = rope_scaling["factor"]
                rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor, dtype)
            else:
                raise ValueError("Dynamic rope scaling must contain either "
                                 "'alpha' or 'factor' field")
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # Only the optional YaRN tuning keys that are actually present
            # are forwarded; absent ones fall back to constructor defaults.
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            # The constructor receives the original (pre-scaling) maximum
            # position, not max_position.
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            # As with "yarn", the original maximum position is passed.
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            # This variant receives both max_position and the original
            # maximum position.
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    # Remember the new instance so later calls with the same key reuse it.
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb

yarn_get_mscale

yarn_get_mscale(
    scale: float = 1, mscale: float = 1
) -> float
Source code in vllm/model_executor/layers/rotary_embedding.py
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return ``0.1 * mscale * ln(scale) + 1.0``, or ``1.0`` when ``scale <= 1``.

    Args:
        scale: Scaling factor the magnitude correction is derived from.
        mscale: Multiplier applied to the logarithmic term.

    Returns:
        The magnitude scale as a float.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0