Skip to content

vllm.model_executor.models.mlp_speculator

SQRT2 module-attribute

# Square root of two; generate_proposals divides the normalized input hidden
# state by this when ``scale_input`` is enabled.
SQRT2 = pow(2, 0.5)

MLPSpeculator

Bases: Module

An implementation of the speculative models introduced in "Accelerating Production LLMs with Combined Token/Embedding Speculators" https://arxiv.org/pdf/2404.19124

Trained speculators of this type are available on HF hub at: https://huggingface.co/ibm-ai-platform and https://huggingface.co/ibm-granite

Source code in vllm/model_executor/models/mlp_speculator.py
class MLPSpeculator(nn.Module):
    """
    An implementation of the speculative models introduced in
    "Accelerating Production LLMs with Combined Token/Embedding
    Speculators"
    https://arxiv.org/pdf/2404.19124

    Trained speculators of this type are available on HF hub at:
    https://huggingface.co/ibm-ai-platform and https://huggingface.co/ibm-granite

    One embedding, projection, layer norm and LM head is kept per speculative
    step (``emb[i]``, ``proj[i]``, ``ln[i]``, ``head[i]``). Step ``i`` mixes
    the previous step's state with the embedding of the previously sampled
    token and predicts the next token with ``head[i]``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the per-step modules from the HF config.

        Args:
            vllm_config: Only ``vllm_config.model_config.hf_config`` is read.
            prefix: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If ``tie_weights`` is set but ``n_predict`` is not
                greater than 1.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        # An inner_dim of 0 means "reuse the base model's embedding width".
        self.inner_dim = (config.inner_dim
                          if config.inner_dim != 0 else config.emb_dim)

        # Number of per-step modules built below; also the upper bound on
        # num_predict_tokens accepted by generate_proposals.
        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input

        if self.tie_weights:
            # Explicit raise rather than `assert`: asserts are stripped under
            # `python -O`, which would silently skip this config check.
            if self.n_predict <= 1:
                raise ValueError(
                    "You cannot tie weights between stages when only 1 exists")
            # `[module] * n` repeats the *same* module object, so every step
            # shares one set of parameters.
            embedding = VocabParallelEmbedding(
                config.vocab_size,
                self.inner_dim,
                org_num_embeddings=config.vocab_size)
            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

            # the initial projection from the base model may
            # have a different size, so that stays separate.
            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                      (self.max_speculative_tokens - 1))

            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
            self.head = nn.ModuleList([head] * self.max_speculative_tokens)

            ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                        elementwise_scale_and_shift=True)
            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)

        else:
            # Independent (untied) modules for every step.
            self.emb = nn.ModuleList([
                VocabParallelEmbedding(config.vocab_size,
                                       self.inner_dim,
                                       org_num_embeddings=config.vocab_size)
                for _ in range(self.max_speculative_tokens)
            ])

            # Only step 0 reads the base model's state (width emb_dim); later
            # steps read the previous step's state (width inner_dim).
            self.proj = nn.ModuleList([
                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                          self.inner_dim,
                          bias=False)
                for i in range(self.max_speculative_tokens)
            ])

            self.head = nn.ModuleList([
                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
                for _ in range(self.max_speculative_tokens)
            ])
            self.ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim,
                                       elementwise_scale_and_shift=True)
                for _ in range(self.max_speculative_tokens)
            ])
        if self.scale_input:
            # Normalization without learned scale/shift, applied to the
            # incoming base-model hidden state in generate_proposals.
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        # Mixing weights for the state/embedding sum in generate_proposals.
        # By construction state_weight**2 + emb_weight**2 / (inner_dim / 2)
        # equals 1.
        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = get_sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> list[SamplerOutput]:
        """Run ``num_predict_tokens`` speculative steps.

        Args:
            input_ids: Presumably the most recent token of each sequence;
                unsqueezed to ``b x 1`` before the first step.
            previous_hidden_states: Base-model hidden states; unsqueezed to
                ``b x 1 x d`` (assumes a ``b x d`` input — confirm with
                callers).
            num_predict_tokens: Number of steps to run; must not exceed
                ``max_speculative_tokens``.
            sampling_metadata: Forwarded to the logits processor and sampler.

        Returns:
            One sampler output per step, in step order. The tokens sampled
            at step ``i`` are fed into step ``i + 1``.

        Raises:
            ValueError: If ``num_predict_tokens`` exceeds
                ``max_speculative_tokens``.
        """
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        if self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            previous_hidden_states = states
            # TODO: not yet supporting top_k_tokens_per_head
            states = states.flatten(0, 1)

            logits = self.logits_processor(self.head[head_index], states,
                                           sampling_metadata)

            output = self.sampler(logits, sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into the matching parameters.

        Every ``"speculator."`` substring is removed from a checkpoint name
        before lookup. Names with no matching parameter are skipped silently.
        With ``tie_weights``, this presumably covers the repeated copies of
        shared modules, because ``named_parameters()`` reports each shared
        parameter only once.

        Returns:
            The set of (stripped) parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            name = name.replace("speculator.", "")
            param = params_dict.get(name)
            if param is not None:
                # Parameters may carry a custom loader (e.g. for sharding);
                # otherwise fall back to the default copy.
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params

activation instance-attribute

activation = GELU()

config instance-attribute

config = config

emb instance-attribute

emb = ModuleList([embedding] * max_speculative_tokens)

emb_dim instance-attribute

emb_dim = emb_dim

emb_weight instance-attribute

emb_weight = sqrt((1 - state_weight ** 2) * (inner_dim / 2))

head instance-attribute

head = ModuleList([head] * max_speculative_tokens)

inner_dim instance-attribute

inner_dim = inner_dim if inner_dim != 0 else emb_dim

ln instance-attribute

ln = ModuleList([ln] * max_speculative_tokens)

ln0 instance-attribute

ln0 = MLPSpeculatorLayerNorm(
    emb_dim, elementwise_scale_and_shift=False
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    vocab_size, vocab_size, 1.0
)

max_speculative_tokens instance-attribute

max_speculative_tokens = num_lookahead_tokens

n_predict instance-attribute

n_predict = n_predict

proj instance-attribute

proj = ModuleList(
    [proj_first] + [proj_tied] * (max_speculative_tokens - 1)
)

sampler instance-attribute

sampler = get_sampler()

scale_input instance-attribute

scale_input = scale_input

state_weight instance-attribute

state_weight = 0.5 ** (0.5 / n_predict)

tie_weights instance-attribute

tie_weights = tie_weights

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    *, vllm_config: VllmConfig, prefix: str = ""
) -> None
Source code in vllm/model_executor/models/mlp_speculator.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    """Build the speculator's per-step modules from the HF config.

    Args:
        vllm_config: Only ``vllm_config.model_config.hf_config`` is read.
        prefix: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If ``tie_weights`` is set but ``n_predict`` is not
            greater than 1.
    """
    super().__init__()
    config = vllm_config.model_config.hf_config
    self.n_predict = config.n_predict
    self.vocab_size = config.vocab_size
    self.emb_dim = config.emb_dim
    # An inner_dim of 0 means "reuse the base model's embedding width".
    self.inner_dim = (config.inner_dim
                      if config.inner_dim != 0 else config.emb_dim)

    # Number of per-step modules built below.
    self.max_speculative_tokens = config.num_lookahead_tokens

    self.tie_weights = config.tie_weights
    self.scale_input = config.scale_input

    if self.tie_weights:
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would silently skip this config check.
        if self.n_predict <= 1:
            raise ValueError(
                "You cannot tie weights between stages when only 1 exists")
        # `[module] * n` repeats the *same* module object, so every step
        # shares one set of parameters.
        embedding = VocabParallelEmbedding(
            config.vocab_size,
            self.inner_dim,
            org_num_embeddings=config.vocab_size)
        self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

        # the initial projection from the base model may
        # have a different size, so that stays separate.
        proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
        proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
        self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                  (self.max_speculative_tokens - 1))

        head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
        self.head = nn.ModuleList([head] * self.max_speculative_tokens)

        ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                    elementwise_scale_and_shift=True)
        self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)

    else:
        # Independent (untied) modules for every step.
        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        # Only step 0 reads the base model's state (width emb_dim).
        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False)
            for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim,
                                   elementwise_scale_and_shift=True)
            for _ in range(self.max_speculative_tokens)
        ])
    if self.scale_input:
        # Normalization without learned scale/shift.
        self.ln0 = MLPSpeculatorLayerNorm(
            self.emb_dim, elementwise_scale_and_shift=False)

    # By construction state_weight**2 + emb_weight**2 / (inner_dim / 2) == 1.
    self.state_weight = 0.5**(0.5 / config.n_predict)
    self.emb_weight = math.sqrt(
        (1 - self.state_weight**2) * (self.inner_dim / 2))
    self.activation = nn.GELU()
    self.config = config
    self.logits_processor = LogitsProcessor(config.vocab_size,
                                            config.vocab_size, 1.0)
    self.sampler = get_sampler()

generate_proposals

generate_proposals(
    input_ids: Tensor,
    previous_hidden_states: Tensor,
    num_predict_tokens: int,
    sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]
Source code in vllm/model_executor/models/mlp_speculator.py
def generate_proposals(
    self,
    input_ids: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    num_predict_tokens: int,
    sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]:
    """Run ``num_predict_tokens`` speculative steps and collect their samples.

    Args:
        input_ids: Presumably the most recent token of each sequence;
            unsqueezed to ``b x 1`` before the first step.
        previous_hidden_states: Base-model hidden states; unsqueezed to
            ``b x 1 x d`` (assumes a ``b x d`` input — confirm with callers).
        num_predict_tokens: Number of steps; must not exceed
            ``max_speculative_tokens``.
        sampling_metadata: Forwarded to the logits processor and sampler.

    Returns:
        One sampler output per step, in step order.

    Raises:
        ValueError: If ``num_predict_tokens`` exceeds
            ``max_speculative_tokens``.
    """
    if num_predict_tokens > self.max_speculative_tokens:
        raise ValueError(f"Max speculative tokens for model is "
                         f"{self.max_speculative_tokens}, but "
                         f"{num_predict_tokens} were requested")

    # Insert a length-1 middle axis: b x d -> b x 1 x d.
    state = previous_hidden_states.unsqueeze(1)
    if self.scale_input:
        state = self.ln0(state) / SQRT2

    # Tokens get the same treatment: b -> b x 1.
    prev_tokens = input_ids.unsqueeze(1)

    outputs = []
    for step in range(num_predict_tokens):
        token_emb = self.emb[step](prev_tokens)
        mixed = self.proj[step](state)
        # In place: mixed += (emb_weight / state_weight) * token_emb. The
        # layer norm that follows absorbs the overall scale, so only the
        # ratio of the two weights is needed here.
        mixed.add_(token_emb, alpha=self.emb_weight / self.state_weight)
        state = self.activation(self.ln[step](mixed))
        # TODO: not yet supporting top_k_tokens_per_head
        logits = self.logits_processor(self.head[step], state.flatten(0, 1),
                                       sampling_metadata)
        sampled = self.sampler(logits, sampling_metadata)
        # The tokens sampled at this step drive the next step.
        prev_tokens = sampled.sampled_token_ids
        outputs.append(sampled)
    return outputs

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/mlp_speculator.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Copy checkpoint tensors into matching parameters.

    Every ``"speculator."`` substring is removed from each checkpoint name
    before lookup, and names without a matching parameter are skipped.

    Returns:
        The set of (stripped) parameter names that were loaded.
    """
    by_name = dict(self.named_parameters())
    loaded = set()
    for raw_name, tensor in weights:
        key = raw_name.replace("speculator.", "")
        target = by_name.get(key)
        if target is None:
            continue
        # Prefer a parameter-specific loader when one is attached.
        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)
        loaded.add(key)
    return loaded

MLPSpeculatorLayerNorm

Bases: Module

An RMS (root-mean-square) normalization implementation: each vector along the final axis is divided by its root-mean-square. Args:

normalized_shape : int — Dimensionality of input data (size of final tensor axis). eps : float — Safety term to prevent division by zero; make sure the chosen value fits in the range of your encoding scheme (e.g. fp16 requires eps >= 6e-8). elementwise_scale_and_shift : bool — Include a learned scaling and shift term after normalization.

Source code in vllm/model_executor/models/mlp_speculator.py
class MLPSpeculatorLayerNorm(nn.Module):
    """
    An RMS (root-mean-square) normalization implementation: each vector along
    the last axis is multiplied by ``rsqrt(mean(x**2) + eps)``, optionally
    followed by a learned elementwise scale and shift.

    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
         fits in the range of your encoding scheme
         (e.g. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
    use_high_precision_pow : bool
        If True, compute the normalization in at least float32 and cast the
        result back to the input dtype. Defaults to False, which computes in
        the input dtype. Enabling it prevents ``x.pow(2)`` from overflowing
        to inf in fp16 when ``|x|`` exceeds about 256; without it, such an
        overflow turns the whole normalized row into zeros.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
        use_high_precision_pow=False,
    ):
        super().__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        self.use_high_precision_pow = use_high_precision_pow
        if self.elementwise_scale_and_shift:
            # Allocated uninitialized (torch.empty); presumably overwritten
            # by checkpoint loading before use.
            self.weight = nn.Parameter(torch.empty(normalized_shape))
            self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        if self.use_high_precision_pow:
            # Upcast half-precision inputs to float32; float64 stays float64.
            xf = x.to(torch.promote_types(x.dtype, torch.float32))
        else:
            xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        # Return in the caller's dtype regardless of the compute dtype.
        x = xf.type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x
            x = x + self.bias
        return x

bias instance-attribute

bias = Parameter(empty(normalized_shape))

elementwise_scale_and_shift instance-attribute

elementwise_scale_and_shift = elementwise_scale_and_shift

eps instance-attribute

eps = eps

weight instance-attribute

weight = Parameter(empty(normalized_shape))

__init__

__init__(
    normalized_shape,
    eps=1e-06,
    elementwise_scale_and_shift=True,
)
Source code in vllm/model_executor/models/mlp_speculator.py
def __init__(
    self,
    normalized_shape,
    eps=1e-06,
    elementwise_scale_and_shift=True,
):
    """Store ``eps`` and, if requested, allocate the scale/shift parameters.

    Both parameters have shape ``normalized_shape``. They are created with
    ``torch.empty`` (uninitialized storage), presumably to be filled by
    checkpoint loading.
    """
    super().__init__()
    self.elementwise_scale_and_shift = elementwise_scale_and_shift
    has_affine = self.elementwise_scale_and_shift
    if has_affine:
        scale_param = nn.Parameter(torch.empty(normalized_shape))
        shift_param = nn.Parameter(torch.empty(normalized_shape))
        self.weight = scale_param
        self.bias = shift_param
    self.eps = eps

forward

forward(x)
Source code in vllm/model_executor/models/mlp_speculator.py
def forward(self, x):
    """Scale ``x`` by the reciprocal RMS of its last axis, then optionally
    apply the learned elementwise weight and bias."""
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    normed = (x * inv_rms).type_as(x)
    if not self.elementwise_scale_and_shift:
        return normed
    return self.weight * normed + self.bias