vllm.model_executor.models.registry

Whenever you add an architecture to this page, please also update tests/models/registry.py with example HuggingFace models for it.
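Each entry there pairs the architecture name with a runnable example checkpoint. The snippet below is only an illustrative sketch: the helper name (_HfExamplesInfo) and the example model IDs are assumptions about that file, not a verbatim excerpt.

# Hypothetical entries in tests/models/registry.py (helper name and IDs assumed):
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct"),
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),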

ModelRegistry module-attribute

ModelRegistry = _ModelRegistry(
    {
        model_arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{mod_relname}",
            class_name=cls_name,
        )
        for (model_arch, (mod_relname, cls_name)) in _VLLM_MODELS.items()
    }
)
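As a quick orientation, the registry can be queried directly. The snippet below is a minimal sketch that uses only methods documented on this page and assumes "LlamaForCausalLM" is registered (it appears in _TEXT_GENERATION_MODELS below):

from vllm.model_executor.models.registry import ModelRegistry

# List every architecture name known to vLLM.
print(sorted(ModelRegistry.get_supported_archs()))

# Resolve an architecture to its implementing class (the module is imported lazily).
model_cls, arch = ModelRegistry.resolve_model_cls("LlamaForCausalLM")

# Interface checks go through the same registry.
assert ModelRegistry.is_text_generation_model("LlamaForCausalLM")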

_CROSS_ENCODER_MODELS module-attribute

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": (
        "bert",
        "BertForSequenceClassification",
    ),
    "RobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "Qwen2ForSequenceClassification": (
        "qwen2",
        "Qwen2ForSequenceClassification",
    ),
    "Qwen3ForSequenceClassification": (
        "qwen3",
        "Qwen3ForSequenceClassification",
    ),
}
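Each value is a (module_relname, class_name) pair; the class is imported lazily from vllm.model_executor.models.<module_relname>, mirroring _LazyRegisteredModel.load_model_cls shown further down. A minimal sketch of that resolution:

import importlib

mod_relname, cls_name = ("bert", "BertForSequenceClassification")
module = importlib.import_module(f"vllm.model_executor.models.{mod_relname}")
model_cls = getattr(module, cls_name)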

_EMBEDDING_MODELS module-attribute

_EMBEDDING_MODELS = {
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": (
        "nemotron_nas",
        "DeciLMForCausalLM",
    ),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": (
        "gpt2",
        "GPT2ForSequenceClassification",
    ),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": (
        "internlm2",
        "InternLM2ForRewardModel",
    ),
    "JambaForSequenceClassification": (
        "jamba",
        "JambaForSequenceClassification",
    ),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        k: (mod, arch)
        for (k, (mod, arch)) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": (
        "qwen2_rm",
        "Qwen2ForRewardModel",
    ),
    "Qwen2ForProcessRewardModel": (
        "qwen2_rm",
        "Qwen2ForProcessRewardModel",
    ),
    "RobertaForMaskedLM": (
        "roberta",
        "RobertaEmbeddingModel",
    ),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": (
        "telechat2",
        "TeleChat2ForCausalLM",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": (
        "qwen2_vl",
        "Qwen2VLForConditionalGeneration",
    ),
    "PrithviGeoSpatialMAE": (
        "prithvi_geospatial_mae",
        "PrithviGeoSpatialMAE",
    ),
}

_MULTIMODAL_MODELS module-attribute

_MULTIMODAL_MODELS = {
    "AriaForConditionalGeneration": (
        "aria",
        "AriaForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "Blip2ForConditionalGeneration": (
        "blip2",
        "Blip2ForConditionalGeneration",
    ),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": (
        "deepseek_vl2",
        "DeepseekVLV2ForCausalLM",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": (
        "gemma3_mm",
        "Gemma3ForConditionalGeneration",
    ),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": (
        "glm4_1v",
        "Glm4vForConditionalGeneration",
    ),
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "SmolVLMForConditionalGeneration": (
        "smolvlm",
        "SmolVLMForConditionalGeneration",
    ),
    "KeyeForConditionalGeneration": (
        "keye",
        "KeyeForConditionalGeneration",
    ),
    "KimiVLForConditionalGeneration": (
        "kimi_vl",
        "KimiVLForConditionalGeneration",
    ),
    "LlavaForConditionalGeneration": (
        "llava",
        "LlavaForConditionalGeneration",
    ),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": (
        "llava",
        "MantisForConditionalGeneration",
    ),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": (
        "pixtral",
        "PixtralForConditionalGeneration",
    ),
    "QwenVLForConditionalGeneration": (
        "qwen_vl",
        "QwenVLForConditionalGeneration",
    ),
    "Qwen2VLForConditionalGeneration": (
        "qwen2_vl",
        "Qwen2VLForConditionalGeneration",
    ),
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": (
        "tarsier",
        "TarsierForConditionalGeneration",
    ),
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "Florence2ForConditionalGeneration": (
        "florence2",
        "Florence2ForConditionalGeneration",
    ),
    "MllamaForConditionalGeneration": (
        "mllama",
        "MllamaForConditionalGeneration",
    ),
    "Llama4ForConditionalGeneration": (
        "mllama4",
        "Llama4ForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": (
        "skyworkr1v",
        "SkyworkR1VChatModel",
    ),
    "WhisperForConditionalGeneration": (
        "whisper",
        "WhisperForConditionalGeneration",
    ),
}

_SPECULATIVE_DECODING_MODELS module-attribute

_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": (
        "llama_eagle",
        "EagleLlamaForCausalLM",
    ),
    "EagleMiniCPMForCausalLM": (
        "minicpm_eagle",
        "EagleMiniCPMForCausalLM",
    ),
    "Eagle3LlamaForCausalLM": (
        "llama_eagle3",
        "Eagle3LlamaForCausalLM",
    ),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": (
        "mlp_speculator",
        "MLPSpeculator",
    ),
}

_SUBPROCESS_COMMAND module-attribute

_SUBPROCESS_COMMAND = [
    executable,
    "-m",
    "vllm.model_executor.models.registry",
]

_T module-attribute

_T = TypeVar('_T')

_TEXT_GENERATION_MODELS module-attribute

_TEXT_GENERATION_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": (
        "minimax_text_01",
        "MiniMaxText01ForCausalLM",
    ),
    "MiniMaxM1ForCausalLM": (
        "minimax_text_01",
        "MiniMaxText01ForCausalLM",
    ),
    "BaiChuanForCausalLM": (
        "baichuan",
        "BaiChuanForCausalLM",
    ),
    "BaichuanForCausalLM": (
        "baichuan",
        "BaichuanForCausalLM",
    ),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": (
        "chatglm",
        "ChatGLMForCausalLM",
    ),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": (
        "nemotron_nas",
        "DeciLMForCausalLM",
    ),
    "DeepseekForCausalLM": (
        "deepseek",
        "DeepseekForCausalLM",
    ),
    "DeepseekV2ForCausalLM": (
        "deepseek_v2",
        "DeepseekV2ForCausalLM",
    ),
    "DeepseekV3ForCausalLM": (
        "deepseek_v2",
        "DeepseekV3ForCausalLM",
    ),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": (
        "ernie45",
        "Ernie4_5_ForCausalLM",
    ),
    "Ernie4_5_MoeForCausalLM": (
        "ernie45_moe",
        "Ernie4_5_MoeForCausalLM",
    ),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": (
        "fairseq2_llama",
        "Fairseq2LlamaForCausalLM",
    ),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForConditionalGeneration": (
        "gemma3n",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": (
        "gpt_bigcode",
        "GPTBigCodeForCausalLM",
    ),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": (
        "gpt_neox",
        "GPTNeoXForCausalLM",
    ),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": (
        "granitemoe",
        "GraniteMoeForCausalLM",
    ),
    "GraniteMoeHybridForCausalLM": (
        "granitemoehybrid",
        "GraniteMoeHybridForCausalLM",
    ),
    "GraniteMoeSharedForCausalLM": (
        "granitemoeshared",
        "GraniteMoeSharedForCausalLM",
    ),
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": (
        "hunyuan_v1_moe",
        "HunYuanMoEV1ForCausalLM",
    ),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": (
        "internlm2",
        "InternLM2ForCausalLM",
    ),
    "InternLM2VEForCausalLM": (
        "internlm2_ve",
        "InternLM2VEForCausalLM",
    ),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": (
        "falcon_h1",
        "FalconH1ForCausalLM",
    ),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": (
        "minicpm3",
        "MiniCPM3ForCausalLM",
    ),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": (
        "mixtral_quant",
        "MixtralForCausalLM",
    ),
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": (
        "nemotron",
        "NemotronForCausalLM",
    ),
    "NemotronHForCausalLM": (
        "nemotron_h",
        "NemotronHForCausalLM",
    ),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": (
        "persimmon",
        "PersimmonForCausalLM",
    ),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": (
        "phi3_small",
        "Phi3SmallForCausalLM",
    ),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": (
        "qwen2_moe",
        "Qwen2MoeForCausalLM",
    ),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": (
        "qwen3_moe",
        "Qwen3MoeForCausalLM",
    ),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": (
        "stablelm",
        "StablelmForCausalLM",
    ),
    "StableLmForCausalLM": (
        "stablelm",
        "StablelmForCausalLM",
    ),
    "Starcoder2ForCausalLM": (
        "starcoder2",
        "Starcoder2ForCausalLM",
    ),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": (
        "telechat2",
        "TeleChat2ForCausalLM",
    ),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": (
        "bart",
        "BartForConditionalGeneration",
    ),
}

_TRANSFORMERS_MODELS module-attribute

_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": (
        "transformers",
        "TransformersForCausalLM",
    )
}

_VLLM_MODELS module-attribute
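Assembled from the per-category dictionaries defined on this page; the value below is a reconstruction based on how ModelRegistry is built above, assuming the standard layout of this module:

_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}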

logger module-attribute

logger = init_logger(__name__)

_BaseRegisteredModel

Bases: ABC

Source code in vllm/model_executor/models/registry.py
class _BaseRegisteredModel(ABC):

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        raise NotImplementedError

inspect_model_cls abstractmethod

inspect_model_cls() -> _ModelInfo
Source code in vllm/model_executor/models/registry.py
@abstractmethod
def inspect_model_cls(self) -> _ModelInfo:
    raise NotImplementedError

load_model_cls abstractmethod

load_model_cls() -> type[Module]
Source code in vllm/model_executor/models/registry.py
@abstractmethod
def load_model_cls(self) -> type[nn.Module]:
    raise NotImplementedError

_LazyRegisteredModel dataclass

Bases: _BaseRegisteredModel

Represents a model that has not been imported in the main process.

Source code in vllm/model_executor/models/registry.py
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)

class_name instance-attribute

class_name: str

module_name instance-attribute

module_name: str

__init__

__init__(module_name: str, class_name: str) -> None

inspect_model_cls

inspect_model_cls() -> _ModelInfo
Source code in vllm/model_executor/models/registry.py
def inspect_model_cls(self) -> _ModelInfo:
    return _run_in_subprocess(
        lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

load_model_cls

load_model_cls() -> type[Module]
Source code in vllm/model_executor/models/registry.py
def load_model_cls(self) -> type[nn.Module]:
    mod = importlib.import_module(self.module_name)
    return getattr(mod, self.class_name)

_ModelInfo dataclass

Source code in vllm/model_executor/models/registry.py
@dataclass(frozen=True)
class _ModelInfo:
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
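These flags back the is_* helpers on _ModelRegistry. A minimal sketch of inspecting an architecture, assuming "LlamaForCausalLM" is registered:

from vllm.model_executor.models.registry import ModelRegistry

# inspect_model_cls returns the _ModelInfo plus the resolved architecture name.
info, arch = ModelRegistry.inspect_model_cls(["LlamaForCausalLM"])
print(arch, info.is_text_generation_model, info.supports_pp)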

architecture instance-attribute

architecture: str

has_inner_state instance-attribute

has_inner_state: bool

has_noops instance-attribute

has_noops: bool

is_attention_free instance-attribute

is_attention_free: bool

is_hybrid instance-attribute

is_hybrid: bool

is_pooling_model instance-attribute

is_pooling_model: bool

is_text_generation_model instance-attribute

is_text_generation_model: bool

supports_cross_encoding instance-attribute

supports_cross_encoding: bool

supports_multimodal instance-attribute

supports_multimodal: bool

supports_pp instance-attribute

supports_pp: bool

supports_transcription instance-attribute

supports_transcription: bool

supports_v0_only instance-attribute

supports_v0_only: bool

__init__

__init__(
    architecture: str,
    is_text_generation_model: bool,
    is_pooling_model: bool,
    supports_cross_encoding: bool,
    supports_multimodal: bool,
    supports_pp: bool,
    has_inner_state: bool,
    is_attention_free: bool,
    is_hybrid: bool,
    has_noops: bool,
    supports_transcription: bool,
    supports_v0_only: bool,
) -> None

from_model_cls staticmethod

from_model_cls(model: type[Module]) -> _ModelInfo
Source code in vllm/model_executor/models/registry.py
@staticmethod
def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
    return _ModelInfo(
        architecture=model.__name__,
        is_text_generation_model=is_text_generation_model(model),
        is_pooling_model=True,  # Can convert any model into a pooling model
        supports_cross_encoding=supports_cross_encoding(model),
        supports_multimodal=supports_multimodal(model),
        supports_pp=supports_pp(model),
        has_inner_state=has_inner_state(model),
        is_attention_free=is_attention_free(model),
        is_hybrid=is_hybrid(model),
        supports_transcription=supports_transcription(model),
        supports_v0_only=supports_v0_only(model),
        has_noops=has_noops(model),
    )

_ModelRegistry dataclass

Source code in vllm/model_executor/models/registry.py
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_arch)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out unsupported architectures
        normalized_arch = list(
            filter(lambda model: model in self.models, architectures))

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return not model_cls.supports_v0_only

models class-attribute instance-attribute

models: dict[str, _BaseRegisteredModel] = field(
    default_factory=dict
)

__init__

__init__(
    models: dict[str, _BaseRegisteredModel] = dict(),
) -> None

_normalize_archs

_normalize_archs(
    architectures: Union[str, list[str]],
) -> list[str]
Source code in vllm/model_executor/models/registry.py
def _normalize_archs(
    self,
    architectures: Union[str, list[str]],
) -> list[str]:
    if isinstance(architectures, str):
        architectures = [architectures]
    if not architectures:
        logger.warning("No model architectures are specified")

    # filter out unsupported architectures
    normalized_arch = list(
        filter(lambda model: model in self.models, architectures))

    # make sure Transformers backend is put at the last as a fallback
    if len(normalized_arch) != len(architectures):
        normalized_arch.append("TransformersForCausalLM")
    return normalized_arch
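The fallback behaviour follows directly from the code above: any architecture name not present in models is dropped and the Transformers backend is appended. A sketch assuming the default registry contents ("SomeUnknownArch" is a hypothetical name):

from vllm.model_executor.models.registry import ModelRegistry

ModelRegistry._normalize_archs(["LlamaForCausalLM"])
# -> ["LlamaForCausalLM"]

ModelRegistry._normalize_archs(["SomeUnknownArch"])
# -> ["TransformersForCausalLM"]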

_raise_for_unsupported

_raise_for_unsupported(architectures: list[str])
Source code in vllm/model_executor/models/registry.py
def _raise_for_unsupported(self, architectures: list[str]):
    all_supported_archs = self.get_supported_archs()

    if any(arch in all_supported_archs for arch in architectures):
        raise ValueError(
            f"Model architectures {architectures} failed "
            "to be inspected. Please check the logs for more details.")

    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {all_supported_archs}")

_try_inspect_model_cls

_try_inspect_model_cls(
    model_arch: str,
) -> Optional[_ModelInfo]
Source code in vllm/model_executor/models/registry.py
def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
    if model_arch not in self.models:
        return None

    return _try_inspect_model_cls(model_arch, self.models[model_arch])

_try_load_model_cls

_try_load_model_cls(
    model_arch: str,
) -> Optional[type[Module]]
Source code in vllm/model_executor/models/registry.py
def _try_load_model_cls(self,
                        model_arch: str) -> Optional[type[nn.Module]]:
    if model_arch not in self.models:
        return None

    return _try_load_model_cls(model_arch, self.models[model_arch])

get_supported_archs

get_supported_archs() -> Set[str]
Source code in vllm/model_executor/models/registry.py
def get_supported_archs(self) -> Set[str]:
    return self.models.keys()

inspect_model_cls

inspect_model_cls(
    architectures: Union[str, list[str]],
) -> tuple[_ModelInfo, str]
Source code in vllm/model_executor/models/registry.py
def inspect_model_cls(
    self,
    architectures: Union[str, list[str]],
) -> tuple[_ModelInfo, str]:
    architectures = self._normalize_archs(architectures)

    for arch in architectures:
        model_info = self._try_inspect_model_cls(arch)
        if model_info is not None:
            return (model_info, arch)

    return self._raise_for_unsupported(architectures)

is_attention_free_model

is_attention_free_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_attention_free_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.is_attention_free

is_cross_encoder_model

is_cross_encoder_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_cross_encoder_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.supports_cross_encoding

is_hybrid_model

is_hybrid_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_hybrid_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.is_hybrid

is_multimodal_model

is_multimodal_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_multimodal_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.supports_multimodal

is_noops_model

is_noops_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_noops_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.has_noops

is_pooling_model

is_pooling_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_pooling_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.is_pooling_model

is_pp_supported_model

is_pp_supported_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_pp_supported_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.supports_pp

is_text_generation_model

is_text_generation_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_text_generation_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.is_text_generation_model

is_transcription_model

is_transcription_model(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_transcription_model(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.supports_transcription

is_v1_compatible

is_v1_compatible(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def is_v1_compatible(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return not model_cls.supports_v0_only

model_has_inner_state

model_has_inner_state(
    architectures: Union[str, list[str]],
) -> bool
Source code in vllm/model_executor/models/registry.py
def model_has_inner_state(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    model_cls, _ = self.inspect_model_cls(architectures)
    return model_cls.has_inner_state

register_model

register_model(
    model_arch: str, model_cls: Union[type[Module], str]
) -> None

Register an external model to be used in vLLM.

model_cls can be either:

  • A torch.nn.Module class directly referencing the model.
  • A string in the format <module>:<class> which can be used to lazily import the model. This is useful to avoid initializing CUDA when importing the model and thus the related error RuntimeError: Cannot re-initialize CUDA in forked subprocess.
Source code in vllm/model_executor/models/registry.py
def register_model(
    self,
    model_arch: str,
    model_cls: Union[type[nn.Module], str],
) -> None:
    """
    Register an external model to be used in vLLM.

    `model_cls` can be either:

    - A [`torch.nn.Module`][] class directly referencing the model.
    - A string in the format `<module>:<class>` which can be used to
      lazily import the model. This is useful to avoid initializing CUDA
      when importing the model and thus the related error
      `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
    """
    if not isinstance(model_arch, str):
        msg = f"`model_arch` should be a string, not a {type(model_arch)}"
        raise TypeError(msg)

    if model_arch in self.models:
        logger.warning(
            "Model architecture %s is already registered, and will be "
            "overwritten by the new model class %s.", model_arch,
            model_cls)

    if isinstance(model_cls, str):
        split_str = model_cls.split(":")
        if len(split_str) != 2:
            msg = "Expected a string in the format `<module>:<class>`"
            raise ValueError(msg)

        model = _LazyRegisteredModel(*split_str)
    elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
        model = _RegisteredModel.from_model_cls(model_cls)
    else:
        msg = ("`model_cls` should be a string or PyTorch model class, "
               f"not a {type(model_arch)}")
        raise TypeError(msg)

    self.models[model_arch] = model
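For example, an out-of-tree model can be registered with the lazy string form described above; the plugin module and class names here are hypothetical:

from vllm import ModelRegistry

# The string form defers the import until the model class is actually needed,
# avoiding CUDA initialization in the registering process.
ModelRegistry.register_model(
    "MyCustomForCausalLM",
    "my_plugin.models.my_custom:MyCustomForCausalLM",
)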

resolve_model_cls

resolve_model_cls(
    architectures: Union[str, list[str]],
) -> tuple[type[Module], str]
Source code in vllm/model_executor/models/registry.py
def resolve_model_cls(
    self,
    architectures: Union[str, list[str]],
) -> tuple[type[nn.Module], str]:
    architectures = self._normalize_archs(architectures)

    for arch in architectures:
        model_cls = self._try_load_model_cls(arch)
        if model_cls is not None:
            return (model_cls, arch)

    return self._raise_for_unsupported(architectures)

_RegisteredModel dataclass

Bases: _BaseRegisteredModel

Represents a model that has already been imported in the main process.

Source code in vllm/model_executor/models/registry.py
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls

interfaces instance-attribute

interfaces: _ModelInfo

model_cls instance-attribute

model_cls: type[Module]

__init__

__init__(
    interfaces: _ModelInfo, model_cls: type[Module]
) -> None

from_model_cls staticmethod

from_model_cls(model_cls: type[Module])
Source code in vllm/model_executor/models/registry.py
@staticmethod
def from_model_cls(model_cls: type[nn.Module]):
    return _RegisteredModel(
        interfaces=_ModelInfo.from_model_cls(model_cls),
        model_cls=model_cls,
    )

inspect_model_cls

inspect_model_cls() -> _ModelInfo
Source code in vllm/model_executor/models/registry.py
def inspect_model_cls(self) -> _ModelInfo:
    return self.interfaces

load_model_cls

load_model_cls() -> type[Module]
Source code in vllm/model_executor/models/registry.py
def load_model_cls(self) -> type[nn.Module]:
    return self.model_cls

_run

_run() -> None
Source code in vllm/model_executor/models/registry.py
def _run() -> None:
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))

_run_in_subprocess

_run_in_subprocess(fn: Callable[[], _T]) -> _T
Source code in vllm/model_executor/models/registry.py
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{returned.stderr.decode()}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)

_try_inspect_model_cls cached

_try_inspect_model_cls(
    model_arch: str, model: _BaseRegisteredModel
) -> Optional[_ModelInfo]
Source code in vllm/model_executor/models/registry.py
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None

_try_load_model_cls cached

_try_load_model_cls(
    model_arch: str, model: _BaseRegisteredModel
) -> Optional[type[Module]]
Source code in vllm/model_executor/models/registry.py
@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None