vllm.model_executor.models.gritlm
GritLM
Bases: LlamaForCausalLM, SupportsV0Only
This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
The class inherits from LlamaForCausalLM and provides a custom pooling
layer.
The main difference between the pooling layer in GritLM and the one in
LlamaForCausalLM is that GritLM ignores the query instruction in the prompt
when pooling the hidden states.
Embedding prompts should be in the following format:
- With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
- Without instruction: "<|embed|>\nPROMPT".
Generation prompts should be in the following format:
- "<|user|>\nPROMPT\n<|assistant|>\n"
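The prompt formats above can be assembled with small string helpers. The sketch below is illustrative only: `build_embedding_prompt` and `build_generation_prompt` are hypothetical names that are not part of vLLM, and it assumes the separators between the markers and the text are single newline characters.

```python
from typing import Optional


def build_embedding_prompt(text: str, instruction: Optional[str] = None) -> str:
    """Hypothetical helper: format an embedding prompt for GritLM.

    Assumes newline separators between the markers and the text.
    """
    if instruction:
        # With instruction: user marker, instruction, embed marker, then the text.
        return f"<|user|>\n{instruction}\n<|embed|>\n{text}"
    # Without instruction: embed marker followed by the text.
    return f"<|embed|>\n{text}"


def build_generation_prompt(text: str) -> str:
    """Hypothetical helper: format a generation prompt for GritLM."""
    return f"<|user|>\n{text}\n<|assistant|>\n"
```

Keeping the formatting in one place makes it less likely that a query is embedded with a different layout than the one the pooler expects to match.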
Source code in vllm/model_executor/models/gritlm.py
__init__
__init__(
    vllm_config: VllmConfig, prefix: str = "", **kwargs
) -> None
pooler
pooler(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> Optional[PoolerOutput]
GritLMPooler
Bases: Module
embed_newline_pattern_ids (instance attribute)
embed_pattern_ids (instance attribute)
token_ids (instance attribute)
token_ids = {
    tok: convert_tokens_to_ids([tok])[0]
    for tok in [
        "<s>",
        "▁<",
        "<",
        "|",
        "embed",
        ">",
        "<0x0A>",
        "user",
    ]
}
user_pattern_ids (instance attribute)
__init__
__init__(model_config: ModelConfig)
_find_array
Find the first occurrence of target in arr starting from start_idx.
Args:
    arr: The array to search within.
    target: The consecutive subsequence to find.
    start_idx: The starting index to search from.
Returns:
    int: The index of the first occurrence of target in arr.
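A subsequence search of this kind can be sketched in a few lines of plain Python. This is not the vLLM implementation: `find_subsequence` is a hypothetical name, and returning -1 when there is no match is an assumption, since the description above does not say what happens in that case.

```python
from typing import Sequence


def find_subsequence(
    arr: Sequence[int], target: Sequence[int], start_idx: int = 0
) -> int:
    """Return the index of the first occurrence of target in arr at or
    after start_idx, or -1 if there is none (assumed convention)."""
    if start_idx < 0:
        raise ValueError("start_idx must be non-negative")
    if len(target) == 0:
        raise ValueError("target must not be empty")

    m = len(target)
    # Slide a window of len(target) over arr and compare element-wise.
    for i in range(start_idx, len(arr) - m + 1):
        if list(arr[i:i + m]) == list(target):
            return i
    return -1
```

For example, `find_subsequence([1, 2, 3, 2, 3], [2, 3], start_idx=2)` skips the first match at index 1 and reports the second one.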
_get_instruction_len
Get the length of the instruction in the prompt.
The method pattern-matches the prompt to locate the instruction and returns the instruction's length.
Matching is done on integers rather than strings because the prompt is given as a list of token IDs.
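The idea of matching on token IDs can be illustrated with a toy example. Everything here is hypothetical: the ID values, the `EMBED_PATTERN` constant, and the `instruction_len` function are invented for illustration, and the rule that the instruction ends right after the embed marker (with only a leading BOS token counted when no marker is found) is an assumption rather than a statement about the actual method.

```python
from typing import List

# Toy token IDs (hypothetical): stand-in for the tokens of "<|embed|>\n".
EMBED_PATTERN: List[int] = [30, 31]


def _find(arr: List[int], target: List[int], start_idx: int = 0) -> int:
    """First index of target in arr at or after start_idx, else -1."""
    m = len(target)
    for i in range(start_idx, len(arr) - m + 1):
        if arr[i:i + m] == target:
            return i
    return -1


def instruction_len(prompt_ids: List[int]) -> int:
    """Number of leading tokens treated as instruction (assumed rule)."""
    # Skip index 0, assumed to hold the BOS token.
    idx = _find(prompt_ids, EMBED_PATTERN, start_idx=1)
    if idx == -1:
        # Assumption: with no embed marker, only BOS is excluded.
        return 1
    # Everything up to and including the embed marker is instruction.
    return idx + len(EMBED_PATTERN)


# BOS, user marker (20, 21), two instruction tokens, embed marker, three text tokens.
with_instruction = [1, 20, 21, 100, 101, 30, 31, 200, 201, 202]
# BOS, embed marker, two text tokens.
without_instruction = [1, 30, 31, 200, 201]
```

Under these assumptions, the tokens from `instruction_len(prompt_ids)` onward are the ones that would contribute to the embedding.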
forward
forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> PoolerOutput
Pool the hidden states by summing the embeddings of non-instruction tokens.
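The pooling step described above can be sketched without any tensor library. This is a simplified illustration for a single prompt, not the vLLM code: `pool_non_instruction` is a hypothetical name, hidden states are represented as plain lists of floats, and the instruction length is assumed to be known already.

```python
from typing import List


def pool_non_instruction(
    hidden_states: List[List[float]], instruction_len: int
) -> List[float]:
    """Sum the per-token hidden states, skipping the first
    instruction_len tokens (hypothetical helper)."""
    kept = hidden_states[instruction_len:]
    if not kept:
        raise ValueError("no non-instruction tokens to pool")

    dim = len(kept[0])
    # Sum each hidden dimension across the remaining tokens.
    return [sum(row[d] for row in kept) for d in range(dim)]


# Four tokens with hidden size 2; the first two belong to the instruction.
states = [[1.0, 1.0], [2.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled = pool_non_instruction(states, instruction_len=2)
```

Dividing each component of the result by `len(kept)` would turn the sum into a mean over the same tokens.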