Skip to content

vllm.model_executor.models.gpt_j

Inference-only GPT-J model compatible with HuggingFace weights.

GPTJAttention

Bases: Module

Source code in vllm/model_executor/models/gpt_j.py
class GPTJAttention(nn.Module):
    """Self-attention sub-layer of a GPT-J decoder block.

    Q, K and V come from one fused projection. Rotary position embeddings
    are applied to Q and K, attention is computed, and the result is
    passed through an output projection. Heads are divided evenly across
    tensor-parallel ranks.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: GPT-J config. ``num_attention_heads``, ``hidden_size``
                and ``rotary_dim`` are read, plus the optional ``rotary``,
                ``rope_theta`` (default 10000) and
                ``max_position_embeddings`` (default 8192).
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to both projections and to ``Attention``.
            prefix: Dotted name of this module. Sub-layers are named
                ``{prefix}.qkv_proj``, ``{prefix}.out_proj`` and
                ``{prefix}.attn``.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Every sub-layer gets its dotted name (as Attention already did) so
        # quantization configs that select layers by name can address the
        # projections too.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        # Each tensor-parallel rank must own the same number of heads.
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # Only rotary configs with an even rotary_dim are supported.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Only the first rotary_dim channels of each head are rotated;
        # is_neox_style=False selects the non-NeoX rotary layout.
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run rotary self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions fed to the rotary embedding.
            hidden_states: Input activations of this layer.

        Returns:
            The attention output after ``out_proj``.
        """
        # The second value returned by each projection is discarded.
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is split into equal Q, K, V parts on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        # Rotary embeddings are applied to queries and keys only.
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output

attn instance-attribute

attn = Attention(
    num_heads,
    head_size,
    scaling,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

head_size instance-attribute

head_size = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

num_heads instance-attribute

num_heads = total_num_heads // tp_world_size

out_proj instance-attribute

out_proj = RowParallelLinear(
    hidden_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
)

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_size,
    total_num_heads,
    bias=False,
    quant_config=quant_config,
)

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_size,
    rotary_dim=rotary_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    is_neox_style=False,
)

total_num_heads instance-attribute

total_num_heads = num_attention_heads

__init__

__init__(
    config: GPTJConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/gpt_j.py
def __init__(
    self,
    config: GPTJConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build GPT-J attention: projections, rotary embedding, attention layer.

    Args:
        config: GPT-J config. ``num_attention_heads``, ``hidden_size`` and
            ``rotary_dim`` are read, plus the optional ``rotary``,
            ``rope_theta`` (default 10000) and ``max_position_embeddings``
            (default 8192).
        cache_config: Forwarded to the ``Attention`` layer.
        quant_config: Forwarded to both projections and to ``Attention``.
        prefix: Dotted name of this module. Sub-layers are named
            ``{prefix}.qkv_proj``, ``{prefix}.out_proj`` and
            ``{prefix}.attn``.
    """
    super().__init__()
    self.total_num_heads = config.num_attention_heads
    self.hidden_size = config.hidden_size
    self.head_size = self.hidden_size // self.total_num_heads

    # Every sub-layer gets its dotted name (as Attention already did) so
    # quantization configs that select layers by name can address the
    # projections too.
    self.qkv_proj = QKVParallelLinear(
        config.hidden_size,
        self.head_size,
        self.total_num_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.out_proj = RowParallelLinear(
        config.hidden_size,
        config.hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )

    tp_world_size = get_tensor_model_parallel_world_size()
    # Each tensor-parallel rank must own the same number of heads.
    assert self.total_num_heads % tp_world_size == 0
    self.num_heads = self.total_num_heads // tp_world_size

    scaling = self.head_size**-0.5
    # Only rotary configs with an even rotary_dim are supported.
    assert getattr(config, "rotary", True)
    assert config.rotary_dim % 2 == 0
    rope_theta = getattr(config, "rope_theta", 10000)
    max_position_embeddings = getattr(config, "max_position_embeddings",
                                      8192)
    # is_neox_style=False selects the non-NeoX rotary layout.
    self.rotary_emb = get_rope(
        self.head_size,
        rotary_dim=config.rotary_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        is_neox_style=False,
    )
    self.attn = Attention(self.num_heads,
                          self.head_size,
                          scaling,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.attn")

forward

forward(
    position_ids: Tensor, hidden_states: Tensor
) -> Tensor
Source code in vllm/model_executor/models/gpt_j.py
def forward(
    self,
    position_ids: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Apply GPT-J self-attention with rotary position embeddings.

    Args:
        position_ids: Positions handed to the rotary embedding.
        hidden_states: Layer input activations.

    Returns:
        Result of ``out_proj`` applied to the attention output.
    """
    fused, _ = self.qkv_proj(hidden_states)
    # Equal three-way split of the fused projection: query, key, value.
    query, key, value = fused.chunk(chunks=3, dim=-1)
    query, key = self.rotary_emb(position_ids, query, key)
    projected, _ = self.out_proj(self.attn(query, key, value))
    return projected

GPTJBlock

Bases: Module

Source code in vllm/model_executor/models/gpt_j.py
class GPTJBlock(nn.Module):
    """One GPT-J decoder layer with a parallel residual.

    A single LayerNorm (``ln_1``) feeds both the attention branch and the
    MLP branch. Their outputs are added to the un-normalized block input.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the shared LayerNorm, the attention and the MLP.

        Args:
            config: GPT-J config; ``n_embd``, ``n_inner`` and
                ``layer_norm_epsilon`` are read here.
            cache_config: Forwarded to ``GPTJAttention``.
            quant_config: Forwarded to ``GPTJAttention`` and ``GPTJMLP``.
            prefix: Dotted module name; the attention is named
                ``{prefix}.attn``.
        """
        super().__init__()
        # The MLP width falls back to 4 * n_embd when n_inner is unset.
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.mlp = GPTJMLP(inner_dim, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``attn(ln_1(x)) + mlp(ln_1(x)) + x``."""
        normed = self.ln_1(hidden_states)
        attn_out = self.attn(position_ids=position_ids, hidden_states=normed)
        mlp_out = self.mlp(normed)
        # Both branches read the same normalized input; the raw input is
        # the residual.
        return attn_out + mlp_out + hidden_states

attn instance-attribute

attn = GPTJAttention(
    config,
    cache_config,
    quant_config,
    prefix=f"{prefix}.attn",
)

ln_1 instance-attribute

ln_1 = LayerNorm(n_embd, eps=layer_norm_epsilon)

mlp instance-attribute

mlp = GPTJMLP(inner_dim, config, quant_config)

__init__

__init__(
    config: GPTJConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/gpt_j.py
def __init__(
    self,
    config: GPTJConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Create the GPT-J block's shared LayerNorm, attention and MLP.

    Args:
        config: GPT-J config; ``n_embd``, ``n_inner`` and
            ``layer_norm_epsilon`` are read here.
        cache_config: Forwarded to ``GPTJAttention``.
        quant_config: Forwarded to ``GPTJAttention`` and ``GPTJMLP``.
        prefix: Dotted module name; the attention is named ``{prefix}.attn``.
    """
    super().__init__()
    # The MLP width falls back to 4 * n_embd when n_inner is unset.
    if config.n_inner is None:
        inner_dim = 4 * config.n_embd
    else:
        inner_dim = config.n_inner
    self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
    self.attn = GPTJAttention(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.attn")
    self.mlp = GPTJMLP(inner_dim, config, quant_config)

forward

forward(
    position_ids: Tensor, hidden_states: Tensor
) -> Tensor
Source code in vllm/model_executor/models/gpt_j.py
def forward(
    self,
    position_ids: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Parallel-residual GPT-J block: ``attn(ln(x)) + mlp(ln(x)) + x``.

    Args:
        position_ids: Positions forwarded to the attention sub-layer.
        hidden_states: Block input; also used as the residual.

    Returns:
        The block output.
    """
    normed = self.ln_1(hidden_states)
    attn_out = self.attn(position_ids=position_ids, hidden_states=normed)
    mlp_out = self.mlp(normed)
    return attn_out + mlp_out + hidden_states

GPTJForCausalLM

Bases: Module, SupportsPP

Source code in vllm/model_executor/models/gpt_j.py
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J causal language model: a ``GPTJModel`` backbone plus ``lm_head``.

    The backbone is registered as ``transformer``. Its factory is reused for
    building empty pipeline-parallel intermediate tensors.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the output head and the logits processor.

        Args:
            vllm_config: Engine config; the HF model config and the
                quantization config are read from it.
            prefix: Dotted name of this module; sub-module names are derived
                from it with ``maybe_prefix``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # The head below is a separate layer with its own bias, so tying it
        # to the input embedding is not supported.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # Name the head like the backbone so quantization configs that
        # select layers by name can address it.
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings for ``input_ids`` from the backbone."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; logits are computed later by ``compute_logits``.

        Returns:
            Whatever ``self.transformer`` returns: hidden states, or
            ``IntermediateTensors`` on a non-final pipeline stage.
        """
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        # lm_head is built with bias=True; its bias is passed to the logits
        # processor explicitly.
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs via ``AutoWeightsLoader``.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

config instance-attribute

config = config

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size, n_embd, bias=True, quant_config=quant_config
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    transformer.make_empty_intermediate_tensors
)

quant_config instance-attribute

quant_config = quant_config

transformer instance-attribute

transformer = GPTJModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "transformer"),
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gpt_j.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the GPT-J backbone, the output head and the logits processor.

    Args:
        vllm_config: Engine config; the HF model config and the
            quantization config are read from it.
        prefix: Dotted name of this module; sub-module names are derived
            from it with ``maybe_prefix``.
    """
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config
    # The head below is a separate layer with its own bias, so tying it to
    # the input embedding is not supported.
    assert not config.tie_word_embeddings
    self.transformer = GPTJModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(
                                     prefix, "transformer"))
    # Name the head like the backbone so quantization configs that select
    # layers by name can address it.
    self.lm_head = ParallelLMHead(
        config.vocab_size,
        config.n_embd,
        bias=True,
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "lm_head"),
    )
    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.transformer.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/gpt_j.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Turn final hidden states into vocabulary logits.

    The output head and its bias are both handed to the logits processor.

    Returns:
        The logits processor's result (the annotation allows ``None``).
    """
    head = self.lm_head
    return self.logits_processor(head, hidden_states, sampling_metadata,
                                 head.bias)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/gpt_j.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Delegate to the backbone; logits come from ``compute_logits``.

    Returns:
        Hidden states, or ``IntermediateTensors`` when the backbone returns
        them (non-final pipeline stage).
    """
    return self.transformer(input_ids, positions, intermediate_tensors,
                            inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gpt_j.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` with the backbone's token embedding."""
    backbone = self.transformer
    return backbone.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gpt_j.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint ``(name, tensor)`` pairs through ``AutoWeightsLoader``.

    Returns:
        The set of parameter names the loader reports as loaded.
    """
    return AutoWeightsLoader(self).load_weights(weights)

GPTJMLP

Bases: Module

Source code in vllm/model_executor/models/gpt_j.py
class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block: fc_in, activation, fc_out."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create both projections and the activation function.

        Args:
            intermediate_size: Width between ``fc_in`` and ``fc_out``.
            config: GPT-J config; ``n_embd`` and ``activation_function``
                are read.
            quant_config: Forwarded to both projections.
            prefix: Dotted name of this module, mirroring ``GPTJAttention``.
                The projections are named ``{prefix}.fc_in`` and
                ``{prefix}.fc_out``. It is optional, so existing call sites
                keep working.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_in",
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_out",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_out(act(fc_in(hidden_states)))``."""
        # The second value returned by each projection is discarded.
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states

act instance-attribute

act = get_act_fn(activation_function)

fc_in instance-attribute

fc_in = ColumnParallelLinear(
    hidden_size,
    intermediate_size,
    quant_config=quant_config,
)

fc_out instance-attribute

fc_out = RowParallelLinear(
    intermediate_size,
    hidden_size,
    quant_config=quant_config,
)

__init__

__init__(
    intermediate_size: int,
    config: GPTJConfig,
    quant_config: Optional[QuantizationConfig] = None,
)
Source code in vllm/model_executor/models/gpt_j.py
def __init__(
    self,
    intermediate_size: int,
    config: GPTJConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Create the GPT-J MLP's projections and activation function.

    Args:
        intermediate_size: Width between ``fc_in`` and ``fc_out``.
        config: GPT-J config; ``n_embd`` and ``activation_function`` are read.
        quant_config: Forwarded to both projections.
        prefix: Dotted name of this module, mirroring ``GPTJAttention``.
            The projections are named ``{prefix}.fc_in`` and
            ``{prefix}.fc_out``. It is optional, so existing call sites keep
            working.
    """
    super().__init__()
    hidden_size = config.n_embd
    self.fc_in = ColumnParallelLinear(
        hidden_size,
        intermediate_size,
        quant_config=quant_config,
        prefix=f"{prefix}.fc_in",
    )
    self.fc_out = RowParallelLinear(
        intermediate_size,
        hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.fc_out",
    )
    self.act = get_act_fn(config.activation_function)

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/gpt_j.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Compute ``fc_out(act(fc_in(hidden_states)))``.

    The second value returned by each projection is discarded.
    """
    expanded, _ = self.fc_in(hidden_states)
    projected, _ = self.fc_out(self.act(expanded))
    return projected

GPTJModel

Bases: Module

Source code in vllm/model_executor/models/gpt_j.py
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J backbone: token embedding, decoder layers and final LayerNorm.

    Pipeline-parallel aware. ``make_layers`` supplies the
    ``[start_layer, end_layer)`` range this rank runs. Non-final ranks return
    their activations as ``IntermediateTensors`` instead of applying ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and ``ln_f``.

        Args:
            vllm_config: Engine config; the HF model config, cache config
                and quantization config are read from it.
            prefix: Dotted module name; layers are created under
                ``{prefix}.h``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        # make_layers calls the factory with each layer's prefix and returns
        # the layer range owned by this rank together with the layer list.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # Builds the "hidden_states" placeholder used between pipeline ranks.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` in ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; embedded only on the first pipeline rank
                when ``inputs_embeds`` is not given.
            position_ids: Positions passed to every layer.
            intermediate_tensors: Previous rank's output; read on
                non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                ``wte`` lookup on the first rank.

        Returns:
            ``ln_f``-normalized hidden states on the last rank, otherwise
            ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after the renaming below) that were
            loaded.
        """
        # Checkpoint names fused into one local parameter, with the shard id
        # passed to that parameter's weight_loader.
        # NOTE(review): GPTJMLP uses fc_in/fc_out, so the gate_up_proj
        # entries appear never to match for GPT-J; kept unchanged.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Entries skipped outright; presumably non-parameter attention
            # buffers in the checkpoint -- confirm against the HF layout.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                # NOTE(review): unlike the branches below there is no
                # is_pp_missing_parameter check, so a scale whose parameter
                # is absent from params_dict would raise KeyError -- confirm.
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # A 0-d scale is used as-is; otherwise only element 0 is used.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # `name` is rewritten before the skip checks below. A
                # `continue` there moves on to the next mapping entry with the
                # rewritten name. If no entry breaks, the `else` branch then
                # re-checks that rewritten name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Presumably parameters owned by another pipeline rank.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping loaded it: treat it as a plain parameter.
                name = maybe_remap_kv_scale_name(name, params_dict)
                # A None result means this entry is skipped.
                if name is None:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

embed_dim instance-attribute

embed_dim = n_embd

ln_f instance-attribute

ln_f = LayerNorm(embed_dim, eps=layer_norm_epsilon)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], n_embd
    )
)

quant_config instance-attribute

quant_config = quant_config

wte instance-attribute

wte = VocabParallelEmbedding(vocab_size, embed_dim)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gpt_j.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the GPT-J embedding table, decoder layers and final norm.

    Args:
        vllm_config: Engine config; the HF model config, cache config and
            quantization config are read from it.
        prefix: Dotted module name; decoder layers live under ``{prefix}.h``.
    """
    super().__init__()

    hf_config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config

    self.config = hf_config
    self.quant_config = quant_config
    self.embed_dim = hf_config.n_embd
    self.wte = VocabParallelEmbedding(hf_config.vocab_size, self.embed_dim)

    def build_block(prefix):
        # make_layers passes each layer's name as the `prefix` keyword, so
        # this parameter name must stay `prefix`.
        return GPTJBlock(hf_config, cache_config, quant_config, prefix=prefix)

    self.start_layer, self.end_layer, self.h = make_layers(
        hf_config.n_layer, build_block, prefix=f"{prefix}.h")
    self.ln_f = nn.LayerNorm(self.embed_dim,
                             eps=hf_config.layer_norm_epsilon)
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(["hidden_states"],
                                                hf_config.n_embd))

forward

forward(
    input_ids: Tensor,
    position_ids: Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/gpt_j.py
def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this pipeline rank's decoder layers.

    Returns:
        ``ln_f``-normalized hidden states on the last pipeline rank,
        otherwise ``IntermediateTensors`` carrying ``"hidden_states"``.
    """
    if not get_pp_group().is_first_rank:
        # Later stages resume from the previous stage's activations.
        hidden_states = intermediate_tensors["hidden_states"]
    elif inputs_embeds is None:
        hidden_states = self.get_input_embeddings(input_ids)
    else:
        hidden_states = inputs_embeds

    for block in self.h[self.start_layer:self.end_layer]:
        hidden_states = block(position_ids, hidden_states)

    if get_pp_group().is_last_rank:
        return self.ln_f(hidden_states)
    return IntermediateTensors({"hidden_states": hidden_states})

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gpt_j.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the ``wte`` embeddings of ``input_ids``."""
    embedding_table = self.wte
    return embedding_table(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gpt_j.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names (after the renaming below) that were
        loaded.
    """
    # Checkpoint names fused into one local parameter, with the shard id
    # passed to that parameter's weight_loader.
    # NOTE(review): GPTJMLP uses fc_in/fc_out, so the gate_up_proj entries
    # appear never to match for GPT-J; kept unchanged.
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Entries skipped outright; presumably non-parameter attention
        # buffers in the checkpoint -- confirm against the HF layout.
        if "attn.bias" in name or "attn.masked_bias" in name:
            continue

        if (self.quant_config is not None and
            (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            # NOTE(review): unlike the branches below there is no
            # is_pp_missing_parameter check, so a scale whose parameter is
            # absent from params_dict would raise KeyError -- confirm.
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # A 0-d scale is used as-is; otherwise only element 0 is used.
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            # `name` is rewritten before the skip checks below. A `continue`
            # there moves on to the next mapping entry with the rewritten
            # name. If no entry breaks, the `else` branch then re-checks that
            # rewritten name.
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Presumably parameters owned by another pipeline rank.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No stacked mapping loaded it: treat it as a plain parameter.
            name = maybe_remap_kv_scale_name(name, params_dict)
            # A None result means this entry is skipped.
            if name is None:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params