class SarvamMLAModel(nn.Module):
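    """Sarvam MLA decoder backbone with pipeline-parallel (PP) sharding."""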
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size
self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
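        # The embedding table lives on the first PP rank; when word embeddings
        # are tied to the LM head, the last rank needs a copy as well.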
if get_pp_group().is_first_rank or (
self.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
self.embedding_dropout = torch.nn.Dropout(
getattr(config, "embedding_dropout", 0.0)
)
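        # Build only this rank's contiguous slice of the decoder stack;
        # layers outside [start_layer, end_layer) become PPMissingLayer stubs.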
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: SarvamMLABlock(
vllm_config=vllm_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
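        # Factory for placeholder hidden_states/residual tensors exchanged
        # between PP ranks.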
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

    def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
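        # The first PP rank embeds tokens; later ranks resume from the
        # activations shipped over in intermediate_tensors.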
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
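                # nn.Dropout is a no-op in eval mode, so this only affects
                # training runs.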
hidden_states = self.embedding_dropout(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
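        # Run only the decoder layers assigned to this PP rank.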
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
hidden_states,
positions,
residual,
)
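        # Intermediate ranks hand their activations to the next PP stage.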
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
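        # Last rank: apply the final norm, fusing the residual add when one
        # exists (residual is None only if no layer ran on this rank).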
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
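        """Build the checkpoint-name -> fused-MoE-parameter mapping."""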
return SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)

    def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
"""Load weights with stacked gate+up and MoE expert remapping."""
weights = _normalized_weights(weights)
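        # (fused param name, checkpoint weight name, shard index). Dense MLP
        # gate_proj/up_proj checkpoints are packed into one gate_up_proj GEMM.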
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
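        # Each checkpoint tensor takes one of three paths:
        #   1. dense gate/up fusion, e.g. "...mlp.gate_proj.weight" loads into
        #      the gate slice (shard 0) of "...mlp.gate_up_proj.weight";
        #   2. per-expert MoE remapping via expert_params_mapping;
        #   3. a direct 1:1 load for everything else.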
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
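                # Expert weights also contain "gate_proj"/"up_proj" but are
                # handled by the expert mapping below.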
if "mlp.experts" in name:
continue
new_name = name.replace(weight_name, param_name)
if new_name.endswith(".bias") and new_name not in params_dict:
continue
if new_name not in params_dict:
continue
if is_pp_missing_parameter(new_name, self):
continue
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(new_name)
break
else:
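                # for/else: reached only when no stacked mapping matched.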
mapped = False
for (
param_name,
weight_name,
expert_id,
shard_id,
) in expert_params_mapping:
if weight_name not in name:
continue
new_name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(new_name, self):
continue
if new_name not in params_dict:
continue
param = params_dict[new_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(new_name)
mapped = True
break
if mapped:
continue
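            # Fallback: plain 1:1 parameter load (attention, norms, etc.).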
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
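

# A minimal usage sketch, assuming the usual vLLM wiring; the wrapper name and
# its wiring below are illustrative, not part of this file:
#
#   class SarvamMLAForCausalLM(nn.Module):
#       def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
#           super().__init__()
#           self.model = SarvamMLAModel(
#               vllm_config=vllm_config, prefix=f"{prefix}.model"
#           )
#           # ... LM head, logits computation, and load_weights delegation.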