Medusa proposer class for generating token sequences
Source code in vllm/v1/spec_decode/medusa.py
```python
class MedusaProposer:
    """
    Medusa proposer class for generating token sequences
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Save config parameters
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        self.hidden_size = (
            vllm_config.speculative_config.draft_model_config.get_hidden_size())
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        # Run the Medusa heads on the target model's hidden states and
        # compute one set of logits per head.
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks, None)

        # Greedily pick one draft token per head, then transpose so each
        # row holds the draft tokens for a single request.
        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
        return [list(row) for row in zip(*draft_tokens)]

    def load_model(self, target_model: nn.Module) -> None:
        from vllm.compilation.backends import set_model_tag
        with set_model_tag("medusa_head"):
            self.model = get_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.speculative_config.
                draft_model_config)

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
                                    dtype=self.dtype,
                                    device=self.device)
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(hidden_states)
```
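Putting the pieces together, a minimal lifecycle sketch (hypothetical driver code: in vLLM the GPU model runner owns this object, and `vllm_config`, `target_model`, and `sampling_metadata` are assumed to already exist):

```python
import torch

# Hypothetical wiring; normally done inside vLLM's model runner.
proposer = MedusaProposer(vllm_config=vllm_config,
                          device=torch.device("cuda"))
proposer.load_model(target_model)  # attach the Medusa heads
proposer.dummy_run(num_tokens=16)  # warm up with dummy hidden states

# One row of target-model hidden states per request.
hidden = torch.randn(2, proposer.hidden_size,
                     dtype=proposer.dtype, device=proposer.device)
drafts = proposer.propose(hidden, sampling_metadata)
# drafts[i] holds one greedy draft token per Medusa head for request i.
```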
device
instance-attribute
device = device
dtype
instance-attribute
dtype = dtype
hidden_size
instance-attribute
hidden_size = get_hidden_size()
max_num_tokens
instance-attribute
max_num_tokens = max_num_batched_tokens
vllm_config
instance-attribute
vllm_config = vllm_config
__init__
__init__(vllm_config: VllmConfig, device: device)
Source code in vllm/v1/spec_decode/medusa.py
```python
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
):
    # Save config parameters
    self.vllm_config = vllm_config
    self.device = device
    self.max_num_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens)
    self.hidden_size = (
        vllm_config.speculative_config.draft_model_config.get_hidden_size())
    self.dtype = vllm_config.model_config.dtype
```
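The cached values exist for the warmup path: `dummy_run` (below) sizes its dummy buffer as `(max_num_tokens, hidden_size)` in `dtype` on `device`, so the constructor pulls exactly those values out of the scheduler, draft-model, and target-model configs.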
dummy_run
dummy_run(num_tokens: int) -> None
Source code in vllm/v1/spec_decode/medusa.py
```python
@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None:
    hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
                                dtype=self.dtype,
                                device=self.device)
    with set_forward_context(None, self.vllm_config,
                             num_tokens=num_tokens):
        self.model(hidden_states)
```
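A sketch of how a caller might drive the warmup (hypothetical loop; in vLLM this is invoked during profiling and capture runs). Note that the dummy buffer is always allocated at `max_num_tokens` rows regardless of the argument; `num_tokens` only parameterizes the forward context:

```python
# Hypothetical warmup over a few batch sizes (illustrative only).
for n in (1, 64, proposer.max_num_tokens):
    proposer.dummy_run(num_tokens=n)
```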
load_model
load_model(target_model: Module) -> None
Source code in vllm/v1/spec_decode/medusa.py
```python
def load_model(self, target_model: nn.Module) -> None:
    from vllm.compilation.backends import set_model_tag
    with set_model_tag("medusa_head"):
        self.model = get_model(
            vllm_config=self.vllm_config,
            model_config=self.vllm_config.speculative_config.
            draft_model_config)
```
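Two details worth noting: the `set_model_tag("medusa_head")` context labels the heads so their compiled artifacts can be cached separately from the target model's, and `target_model` is accepted for interface parity with other proposers but is never read here, since the Medusa heads are loaded from the draft model config rather than shared with the target model's weights.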
propose
propose(target_hidden_states: Tensor, sampling_metadata: SamplingMetadata) -> list[list[int]]
Source code in vllm/v1/spec_decode/medusa.py
```python
def propose(
    self,
    target_hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
    # Run the Medusa heads on the target model's hidden states and
    # compute one set of logits per head.
    blocks = self.model(target_hidden_states)
    logits = self.model.compute_logits(blocks, None)

    # Greedily pick one draft token per head, then transpose so each
    # row holds the draft tokens for a single request.
    draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
    return [list(row) for row in zip(*draft_tokens)]
```
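To make the argmax-and-transpose step concrete, here is a minimal sketch with fabricated numbers (two Medusa heads, three requests, vocabulary size five; every tensor below is invented for illustration):

```python
import torch

# Two heads, each producing logits of shape [num_requests, vocab_size].
logits = [
    torch.tensor([[0.1, 2.0, 0.3, 0.0, 0.0],   # head 0
                  [1.5, 0.2, 0.1, 0.0, 0.0],
                  [0.0, 0.1, 0.2, 3.0, 0.0]]),
    torch.tensor([[0.0, 0.0, 4.0, 0.1, 0.2],   # head 1
                  [0.3, 0.1, 0.0, 0.0, 2.5],
                  [1.0, 0.9, 0.0, 0.1, 0.0]]),
]

# Per-head greedy picks: [[1, 0, 3], [2, 4, 0]]
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]

# Transposed, one row of draft tokens per request:
# [[1, 2], [0, 4], [3, 0]]
print([list(row) for row in zip(*draft_tokens)])
```

Each head contributes the token for one future position, so request 0's draft continuation is `[1, 2]`, request 1's is `[0, 4]`, and so on.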