Skip to content

vllm.models.deepseek_v4.common.ops.save_partial_states

save_partial_states

save_partial_states(
    kv: Tensor,
    score: Tensor,
    ape: Tensor,
    positions: Tensor,
    state_cache: Tensor,
    slot_mapping: Tensor,
    block_size: int,
    state_width: int,
    compress_ratio: int,
    pdl_kwargs: dict | None = None,
) -> None

Write packed [kv, score+ape] partial states into the compressor cache.

One program per token; pads (slot_id == -1) are skipped.

Source code in vllm/models/deepseek_v4/common/ops/save_partial_states.py
def save_partial_states(
    kv: torch.Tensor,
    score: torch.Tensor,
    ape: torch.Tensor,
    positions: torch.Tensor,
    state_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_size: int,
    state_width: int,
    compress_ratio: int,
    pdl_kwargs: dict | None = None,
) -> None:
    """Write packed [kv, score+ape] partial states into the compressor cache.

    One program per token; pads (slot_id == -1) are skipped.

    The launch grid is 1-D with ``slot_mapping.shape[0]`` programs. The pad
    skipping happens inside ``_save_partial_states_kernel``, which is not
    shown here.

    Args:
        kv: Per-token KV values; ``kv.shape[-1]`` is used as the head size.
        score: Per-token scores; only its leading-dim stride is forwarded.
        ape: Passed to the kernel with its leading-dim stride (presumably an
            additive positional term combined with ``score`` — TODO confirm).
        positions: Token positions, passed to the kernel without a stride.
        state_cache: Destination cache; its strides along dims 0 and 1 are
            forwarded to the kernel.
        slot_mapping: Per-token destination slots; its length sets the grid.
        block_size: Forwarded as a runtime kernel argument.
        state_width: Forwarded as the ``STATE_WIDTH`` constexpr.
        compress_ratio: Forwarded as the ``COMPRESS_RATIO`` constexpr.
        pdl_kwargs: Optional extra keyword arguments forwarded verbatim to the
            kernel launch (presumably programmatic-dependent-launch options).
    """
    head_dim = kv.shape[-1]
    launch_grid = (slot_mapping.shape[0],)

    # Only leading-dim strides are passed for the per-token inputs; presumably
    # the kernel assumes the last dimension is contiguous — TODO confirm.
    kv_row_stride = kv.stride(0)
    score_row_stride = score.stride(0)
    ape_row_stride = ape.stride(0)
    cache_stride_0 = state_cache.stride(0)
    cache_stride_1 = state_cache.stride(1)

    # Compile-time (constexpr) parameters. TRITON_BLOCK_SIZE is the head size
    # rounded up to a power of two.
    constexpr_args = {
        "HEAD_SIZE": head_dim,
        "TRITON_BLOCK_SIZE": triton.next_power_of_2(head_dim),
        "STATE_WIDTH": state_width,
        "COMPRESS_RATIO": compress_ratio,
    }
    extra_launch_args = pdl_kwargs or {}

    # Both dicts are unpacked directly into the call, so a key that appears in
    # both still raises TypeError instead of silently overriding a constexpr.
    _save_partial_states_kernel[launch_grid](
        kv,
        kv_row_stride,
        score,
        score_row_stride,
        ape,
        ape_row_stride,
        positions,
        state_cache,
        cache_stride_0,
        cache_stride_1,
        slot_mapping,
        block_size,
        **constexpr_args,
        **extra_launch_args,
    )