vllm.model_executor.layers.vocab_parallel_embedding
ParallelLMHead
Bases: VocabParallelEmbedding
Parallelized LM head.
Output-logits weight matrices used by the Sampler. The weight and bias tensors are padded so that their vocabulary dimension is divisible by the number of model-parallel GPUs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_embeddings | int | vocabulary size. | required |
embedding_dim | int | size of hidden state. | required |
bias | bool | whether to use bias. | False |
params_dtype | Optional[dtype] | type of the parameters. | None |
org_num_embeddings | Optional[int] | original vocabulary size (without LoRA). | None |
padding_size | int | padding size for the vocabulary. | DEFAULT_VOCAB_PADDING_SIZE |
Source code in vllm/model_executor/layers/vocab_parallel_embedding.py
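To illustrate why the padding matters, the sketch below simulates a vocab-parallel LM head in plain Python on a single process. The weight is padded to a multiple of `pad_to` and split row-wise into equal shards. Each simulated rank computes logits for its own rows, and the per-rank results are concatenated the way an all-gather would concatenate them. The helper names, the zero-filled padding rows, and the final slicing away of padded logits are assumptions of this sketch, not a description of vLLM's code path.

```python
def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
    # Round vocab_size up to the nearest multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def parallel_lm_head_logits(hidden: list[float], weight: list[list[float]],
                            tp_size: int, pad_to: int) -> list[float]:
    """Simulate a vocab-parallel LM head on one process (illustrative only)."""
    vocab_size, dim = len(weight), len(hidden)
    padded_size = pad_vocab_size(vocab_size, pad_to)
    assert padded_size % tp_size == 0, "padded vocab must split evenly"
    # Assumption: padding rows are zeros; their logits are discarded below.
    padded = weight + [[0.0] * dim for _ in range(padded_size - vocab_size)]
    per_rank = padded_size // tp_size
    gathered: list[float] = []
    for rank in range(tp_size):
        shard = padded[rank * per_rank:(rank + 1) * per_rank]  # rows owned by this rank
        # Each rank computes one logit per row it owns: dot(hidden, row).
        gathered.extend(sum(h * w for h, w in zip(hidden, row)) for row in shard)
    return gathered[:vocab_size]  # drop logits that belong to padding rows


hidden = [1.0, 2.0]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0], [0.0, 2.0], [1.0, -1.0]]
print(parallel_lm_head_logits(hidden, weight, tp_size=2, pad_to=4))
# [1.0, 2.0, 3.0, 2.0, 4.0, -1.0]
```

The 6-token vocabulary is padded to 8 rows so that two ranks each own exactly 4; the result is the same as computing all logits on one device.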
bias
instance-attribute
__init__
__init__(
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
)
forward
tie_weights
tie_weights(embed_tokens: VocabParallelEmbedding)
Tie the weights with word embeddings.
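Tying means the LM head stops owning a separate weight and reuses the embedding's weight instead. The sketch below uses hypothetical stand-in classes, with plain lists in place of torch parameters (an assumption made for brevity). After `tie_weights`, both objects reference the same storage, so a change made through the embedding is visible to the head.

```python
class Embedding:
    """Hypothetical stand-in for an embedding layer: one row per token."""

    def __init__(self, weight: list[list[float]]):
        self.weight = weight

    def forward(self, token_ids: list[int]) -> list[list[float]]:
        return [self.weight[t] for t in token_ids]


class LMHead:
    """Hypothetical stand-in for an LM head that starts with its own weight."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.weight = [[0.0] * embedding_dim for _ in range(num_embeddings)]

    def tie_weights(self, embed_tokens: Embedding) -> "LMHead":
        # Share the same object rather than copying its values.
        self.weight = embed_tokens.weight
        return self

    def logits(self, hidden: list[float]) -> list[float]:
        return [sum(h * w for h, w in zip(hidden, row)) for row in self.weight]


embed = Embedding([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
head = LMHead(num_embeddings=3, embedding_dim=2).tie_weights(embed)
embed.weight[2][0] = 5.0          # an update made through the embedding...
print(head.logits([1.0, 1.0]))    # ...is seen by the head: [1.0, 1.0, 6.0]
```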
UnquantizedEmbeddingMethod
Bases: QuantizeMethodBase
Unquantized method for embeddings.
apply
create_weights
create_weights(
layer: Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: dtype,
**extra_weight_attrs,
)
Create weights for embedding layer.
VocabParallelEmbedding
Bases: Module
Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding. Note that we pad the vocabulary size so that it is divisible by the number of model-parallel GPUs.
To support various loading methods, we ensure that LoRA-added embeddings always come at the end of each TP-sharded tensor. In other words, we shard the base embeddings and the LoRA embeddings separately (each padded) and place them in the same tensor. For example, suppose the original vocab size is 1010, the added vocab size is 16, and the padding size is 64. The total padded vocab size is then 1088: 1010 is first padded to 1024, 16 is added to give 1040, and 1040 is then padded to 1088. The tensor layout is as follows:

TP1, rank 0 (no sharding):
                        |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1025 |  -1  | ... |  -1  |
                 index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

TP2, rank 0:
                        |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 1010 | ... | 1025 |  -1  | ... |  -1 |
                 index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  |  528 | ... | 543 |

TP2, rank 1:
                        |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ...  | -1  |  -1  | ... |  -1  | -1  | ... |   -1 |
                 index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  | 528 | ... |  543 |
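The layout arithmetic above can be checked with a short sketch that rebuilds each rank's token_id row from the example numbers. It assumes, as the diagrams suggest, that the padded base region and the padded LoRA region are each split evenly across ranks. The function names are illustrative and not part of vLLM; -1 marks a padding slot.

```python
def round_up(n: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= n.
    return ((n + multiple - 1) // multiple) * multiple


def shard_token_ids(org_vocab: int, added: int, pad_to: int,
                    tp_size: int, rank: int) -> list[int]:
    """Token id stored at each local index of one rank; -1 marks padding."""
    org_padded = round_up(org_vocab, pad_to)             # 1010 -> 1024
    total_padded = round_up(org_padded + added, pad_to)  # 1024 + 16 -> 1088
    added_padded = total_padded - org_padded             # 64 slots for LoRA tokens
    assert org_padded % tp_size == 0 and added_padded % tp_size == 0
    org_per_rank = org_padded // tp_size
    added_per_rank = added_padded // tp_size
    ids = []
    # Base region: this rank's contiguous slice of the padded base vocabulary.
    for t in range(rank * org_per_rank, (rank + 1) * org_per_rank):
        ids.append(t if t < org_vocab else -1)
    # LoRA region: this rank's slice of the padded added vocabulary,
    # whose token ids start right after the base vocabulary.
    for i in range(rank * added_per_rank, (rank + 1) * added_per_rank):
        ids.append(org_vocab + i if i < added else -1)
    return ids


rank0 = shard_token_ids(1010, 16, 64, tp_size=2, rank=0)
rank1 = shard_token_ids(1010, 16, 64, tp_size=2, rank=1)
print(len(rank0), rank0[511], rank0[512], rank0[527], rank0[528])  # 544 511 1010 1025 -1
print(rank1[0], rank1[497], rank1[498])                            # 512 1009 -1
```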
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_embeddings | int | vocabulary size. | required |
embedding_dim | int | size of hidden state. | required |
params_dtype | Optional[dtype] | type of the parameters. | None |
org_num_embeddings | Optional[int] | original vocabulary size (without LoRA). | None |
padding_size | int | padding size for the vocabulary. | DEFAULT_VOCAB_PADDING_SIZE |
quant_config | Optional[QuantizationConfig] | quant config for the layer | None |
prefix | str | full name of the layer in the state dict | '' |
num_added_embeddings_per_partition
instance-attribute
num_embeddings_padded
instance-attribute
num_embeddings_padded = pad_vocab_size(
org_vocab_size_padded + num_added_embeddings,
padding_size,
)
num_embeddings_per_partition
instance-attribute
num_embeddings_per_partition = divide(
num_embeddings_padded, tp_size
)
num_org_embeddings_per_partition
instance-attribute
org_vocab_size_padded
instance-attribute
org_vocab_size_padded = pad_vocab_size(
org_vocab_size, padding_size
)
shard_indices
instance-attribute
shard_indices = _get_indices(
num_embeddings_padded,
org_vocab_size_padded,
num_embeddings,
org_vocab_size,
tp_rank,
tp_size,
)
__init__
__init__(
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
)
_get_indices
classmethod
_get_indices(
vocab_size_padded: int,
org_vocab_size_padded: int,
vocab_size: int,
org_vocab_size: int,
tp_rank: int,
tp_size: int,
) -> VocabParallelEmbeddingShardIndices
Get start and end indices for vocab parallel embedding, following the layout outlined in the class docstring, based on the given tp_rank and tp_size.
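A sketch of that index computation for the running example (1010 base tokens, 16 added, padding 64, so `vocab_size` = 1026 and `vocab_size_padded` = 1088). It assumes each region's padded range is split evenly by rank, with added-token ids offset by the base vocabulary size, and then clipped to the real vocabulary to obtain the unpadded range. `ShardIndices`, `range_for_rank`, and `get_indices` are hypothetical names for illustration, not vLLM's implementation.

```python
from typing import NamedTuple


class ShardIndices(NamedTuple):
    padded_org_start: int
    padded_org_end: int
    padded_added_start: int
    padded_added_end: int
    org_start: int
    org_end: int
    added_start: int
    added_end: int


def range_for_rank(global_size: int, rank: int, world_size: int, offset: int = 0):
    # Even split of a padded range into [start, end); offset shifts the token ids.
    assert global_size % world_size == 0
    per_rank = global_size // world_size
    return rank * per_rank + offset, (rank + 1) * per_rank + offset


def get_indices(vocab_size_padded: int, org_vocab_size_padded: int, vocab_size: int,
                org_vocab_size: int, tp_rank: int, tp_size: int) -> ShardIndices:
    added_padded = vocab_size_padded - org_vocab_size_padded
    p_org_start, p_org_end = range_for_rank(org_vocab_size_padded, tp_rank, tp_size)
    # Added-token ids begin right after the base vocabulary.
    p_add_start, p_add_end = range_for_rank(added_padded, tp_rank, tp_size,
                                            offset=org_vocab_size)
    # Clip to the real vocabulary to drop the padding from each range.
    return ShardIndices(
        p_org_start, p_org_end, p_add_start, p_add_end,
        min(p_org_start, org_vocab_size), min(p_org_end, org_vocab_size),
        min(p_add_start, vocab_size), min(p_add_end, vocab_size),
    )


print(get_indices(1088, 1024, 1026, 1010, tp_rank=0, tp_size=2))
print(get_indices(1088, 1024, 1026, 1010, tp_rank=1, tp_size=2))
```

On rank 1 the clipped added range is empty (1026 to 1026), which matches the all-padding LoRA region in the TP2, rank 1 diagram.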
extra_repr
extra_repr() -> str
forward
get_sharded_to_full_mapping
Get a mapping that can be used to reindex the gathered logits for sampling.
During sampling, we gather logits from all ranks. The index->token_id relationship follows the same layout as outlined in the class docstring. After the gather, however, we want to reindex the final logits tensor so that index and token_id correspond one-to-one (each index equals the token_id it corresponds to). The indices returned by this method allow us to do that.
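For the two-rank example from the class docstring, such a mapping can be sketched as follows. Walk the ranks in order and collect the gathered positions that hold base tokens, then those holding added tokens, then every padding position. Indexing the gathered tensor with this list puts token i at position i. The function, its per-rank count tuples, and the simulated `gathered` list are hypothetical simplifications that assume the docstring layout.

```python
def sharded_to_full_mapping(per_rank_counts: list[tuple[int, int, int, int]],
                            per_partition: int) -> list[int]:
    """per_rank_counts: (num_org, num_org_padded, num_added, num_added_padded) per rank."""
    base, added, padding = [], [], []
    for rank, (n_org, n_org_pad, n_add, n_add_pad) in enumerate(per_rank_counts):
        assert n_org_pad + n_add_pad == per_partition
        start = rank * per_partition  # first gathered position owned by this rank
        base.extend(range(start, start + n_org))
        padding.extend(range(start + n_org, start + n_org_pad))
        added.extend(range(start + n_org_pad, start + n_org_pad + n_add))
        padding.extend(range(start + n_org_pad + n_add, start + n_org_pad + n_add_pad))
    return base + added + padding


# Token id at each gathered position (-1 = padding), following the TP2 diagrams.
rank0 = list(range(512)) + list(range(1010, 1026)) + [-1] * 16
rank1 = list(range(512, 1010)) + [-1] * 14 + [-1] * 32
gathered = rank0 + rank1

mapping = sharded_to_full_mapping([(512, 512, 16, 32), (498, 512, 0, 32)], per_partition=544)
reindexed = [gathered[i] for i in mapping]
print(reindexed[:3], reindexed[1008:1012], reindexed[1025:1027])
# [0, 1, 2] [1008, 1009, 1010, 1011] [1025, -1]
```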
weight_loader
VocabParallelEmbeddingShardIndices
dataclass
Indices for a shard of a vocab parallel embedding.
__init__
__init__(
padded_org_vocab_start_index: int,
padded_org_vocab_end_index: int,
padded_added_vocab_start_index: int,
padded_added_vocab_end_index: int,
org_vocab_start_index: int,
org_vocab_end_index: int,
added_vocab_start_index: int,
added_vocab_end_index: int,
) -> None
__post_init__
get_masked_input_and_mask
get_masked_input_and_mask(
input_: Tensor,
org_vocab_start_index: int,
org_vocab_end_index: int,
num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int,
) -> tuple[Tensor, Tensor]
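No description is given here. Based on the layout in the VocabParallelEmbedding docstring, a plausible reading is that this function converts global token ids into row indices local to one shard and flags the tokens that the shard does not hold, so that their embeddings can be zeroed before results from all ranks are combined (an assumption about how the mask is consumed). The pure-Python sketch below follows that reading, using lists instead of tensors. The offset formula follows from the layout, since added tokens are stored after the shard's padded base rows.

```python
def get_masked_input_and_mask(input_: list[int], org_vocab_start_index: int,
                              org_vocab_end_index: int, num_org_vocab_padding: int,
                              added_vocab_start_index: int,
                              added_vocab_end_index: int) -> tuple[list[int], list[bool]]:
    """Sketch: map global token ids to this shard's row indices.

    mask[i] is True when token i is not stored on this shard; its local id is 0.
    """
    # Added tokens live after the shard's base rows plus their padding.
    added_offset = (added_vocab_start_index
                    - (org_vocab_end_index - org_vocab_start_index)
                    - num_org_vocab_padding)
    local_ids, mask = [], []
    for t in input_:
        if org_vocab_start_index <= t < org_vocab_end_index:
            local_ids.append(t - org_vocab_start_index)
            mask.append(False)
        elif added_vocab_start_index <= t < added_vocab_end_index:
            local_ids.append(t - added_offset)
            mask.append(False)
        else:
            local_ids.append(0)
            mask.append(True)
    return local_ids, mask


tokens = [5, 600, 1010, 1025]
# TP2, rank 0: base ids [0, 512), no base padding, added ids [1010, 1026).
print(get_masked_input_and_mask(tokens, 0, 512, 0, 1010, 1026))
# ([5, 0, 512, 527], [False, True, False, False])
# TP2, rank 1: base ids [512, 1010), 14 padding rows, no added ids.
print(get_masked_input_and_mask(tokens, 512, 1010, 14, 1026, 1026))
# ([0, 88, 0, 0], [True, False, True, True])
```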
pad_vocab_size
pad_vocab_size(
vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE,
) -> int
Round the vocab size up to the nearest multiple of pad_to.
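A minimal sketch of that rounding, assuming ceiling division. It reproduces the sizes used in the VocabParallelEmbedding layout example. `pad_to` is required here rather than defaulting to DEFAULT_VOCAB_PADDING_SIZE, to avoid assuming that constant's value.

```python
def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
    # Round up: add pad_to - 1, floor-divide, then multiply back.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


print(pad_vocab_size(1010, 64))  # 1024
print(pad_vocab_size(1040, 64))  # 1088
print(pad_vocab_size(1024, 64))  # 1024 (already a multiple)
```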
vocab_range_from_global_vocab_size
vocab_range_from_global_vocab_size(
global_vocab_size: int,
rank: int,
world_size: int,
offset: int = 0,
) -> Sequence[int]
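No description is given for this function. Judging from its name and signature, and from how the layout above splits each region across ranks, it presumably returns the half-open [start, end) range of one rank's slice of an evenly divided vocabulary, shifted by `offset`. The sketch below implements that assumption; treat it as an illustration rather than the library's code.

```python
from typing import Sequence


def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                       world_size: int, offset: int = 0) -> Sequence[int]:
    # Assumed behaviour: split the range evenly and return this rank's [start, end).
    if global_vocab_size % world_size != 0:
        raise ValueError("global_vocab_size must be divisible by world_size")
    per_partition = global_vocab_size // world_size
    start = rank * per_partition + offset
    return start, start + per_partition


print(vocab_range_from_global_vocab_size(1024, rank=1, world_size=2))             # (512, 1024)
print(vocab_range_from_global_vocab_size(64, rank=0, world_size=2, offset=1010))  # (1010, 1042)
```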