
vllm.model_executor.models.colbert

ColBERT late interaction model for retrieval and reranking.

ColBERT uses per-token embeddings and late interaction (MaxSim) scoring instead of single-vector representations or cross-encoder concatenation.

Reference: https://arxiv.org/abs/2004.12832

ColBERTModel

Bases: BertEmbeddingModel

ColBERT late interaction model for retrieval/reranking.

This model extends BertEmbeddingModel with a ColBERT-style linear projection layer for per-token embeddings. It supports only:

- "token_embed" task: Per-token embeddings for late interaction

ColBERT is fundamentally a per-token embedding model: the linear projection is trained for per-token representations, not for CLS pooling. Use a dedicated dense embedding model if you need single-vector representations.

The ColBERT scoring (MaxSim) is computed externally, either client-side or via the late interaction scoring path in ServingScores.
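
For intuition, here is a minimal sketch of that external MaxSim computation, as a client might run it over the per-token embeddings (the function name and the per-token L2 normalization are assumptions carried over from the original ColBERT recipe, not part of this module's API):

import torch

def maxsim_score(query_embeds: torch.Tensor, doc_embeds: torch.Tensor) -> torch.Tensor:
    """ColBERT MaxSim: for each query token, take the similarity of its
    best-matching document token, then sum over all query tokens.

    query_embeds: [num_query_tokens, colbert_dim], L2-normalized per token
    doc_embeds:   [num_doc_tokens, colbert_dim], L2-normalized per token
    """
    sim = query_embeds @ doc_embeds.T   # [num_query_tokens, num_doc_tokens]
    return sim.max(dim=1).values.sum()  # max over doc tokens, sum over query tokens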

Attributes:

colbert_linear
    Linear projection from hidden_size to colbert_dim

supports_late_interaction (Literal[True])
    Flag indicating this model uses late interaction scoring

Source code in vllm/model_executor/models/colbert.py
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(BertEmbeddingModel):
    """ColBERT late interaction model for retrieval/reranking.

    This model extends BertEmbeddingModel with a ColBERT-style linear
    projection layer for per-token embeddings. It supports only:
    - "token_embed" task: Per-token embeddings for late interaction

    ColBERT is fundamentally a per-token embedding model - the linear
    projection is trained for per-token representations, not for CLS
    pooling. Use a dedicated dense embedding model if you need
    single-vector representations.

    The ColBERT scoring (MaxSim) is computed externally, either client-side
    or via the late interaction scoring path in ServingScores.

    Attributes:
        colbert_linear: Linear projection from hidden_size to colbert_dim
        supports_late_interaction: Flag indicating this model uses late
            interaction scoring
    """

    # Mark this model as supporting late interaction scoring
    supports_late_interaction: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Get config before calling super().__init__
        config = vllm_config.model_config.hf_config
        self.hidden_size = config.hidden_size
        self.head_dtype = vllm_config.model_config.head_dtype

        # ColBERT dimension - check various config field names used by different
        # ColBERT implementations. If not found in config, will be inferred
        # from loaded weights in load_weights()
        self.colbert_dim: int | None = (
            getattr(config, "colbert_dim", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
        )

        # Initialize parent (this will call _build_pooler)
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        return BertModel(vllm_config=vllm_config, prefix=prefix)

    def _build_colbert_linear(self) -> nn.Linear:
        """Build the ColBERT linear projection layer."""
        if self.colbert_dim is None:
            raise ValueError("colbert_dim must be set before building the linear layer")
        return nn.Linear(
            self.hidden_size,
            self.colbert_dim,
            bias=False,
            dtype=self.head_dtype,
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        # ColBERT linear projection: hidden_size -> colbert_dim
        # Original ColBERT uses bias=False
        # If colbert_dim is not set from config, it will be inferred during
        # load_weights and the linear layer will be created there
        if self.colbert_dim is not None:
            self.colbert_linear = self._build_colbert_linear()
        else:
            # Placeholder - will be created when weights are loaded
            self.colbert_linear = None

        # ColBERT only supports token_embed - it's fundamentally a per-token
        # embedding model.
        return pooler_for_token_embed(
            pooler_config,
            projector=self.colbert_linear,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        def _strip(name: str) -> str:
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        weights_list = list(weights)
        model_side: list[tuple[str, torch.Tensor]] = []
        colbert_side: list[tuple[str, torch.Tensor]] = []

        for name, weight in weights_list:
            stripped = _strip(name)
            # Handle different checkpoint naming conventions for ColBERT linear
            if stripped in ("linear.weight", "colbert_linear.weight"):
                colbert_side.append(("colbert_linear.weight", weight))
            elif stripped.startswith("colbert_linear."):
                colbert_side.append((stripped, weight))
            elif stripped.startswith("linear."):
                # Prefix bare "linear.*" names so they match the attribute name
                colbert_side.append(("colbert_" + stripped, weight))
            else:
                model_side.append((stripped, weight))

        # Load base BERT weights using BertModel.load_weights which handles QKV fusion
        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})

        # Load ColBERT linear weights
        if colbert_side:
            for name, weight in colbert_side:
                if name == "colbert_linear.weight":
                    # Infer colbert_dim from weights if not set in config
                    if self.colbert_dim is None:
                        # Weight shape is [colbert_dim, hidden_size]
                        self.colbert_dim = weight.shape[0]
                        # Create the linear layer now that we know the dimension
                        self.colbert_linear = self._build_colbert_linear()
                        # Move to the same device as the model's existing parameters
                        device = next(self.model.parameters()).device
                        self.colbert_linear.to(device)
                        # Update the pooler's projector to use the new linear layer
                        self.pooler.head.projector = self.colbert_linear

                    # Load weights directly into the pooler's projector
                    weight = weight.to(self.pooler.head.projector.weight.device)
                    self.pooler.head.projector.weight.data.copy_(weight)
                    loaded.add("pooler.head.projector.weight")
                    break

        return loaded

colbert_dim instance-attribute

colbert_dim: int | None = (
    getattr(config, "colbert_dim", None)
    or getattr(config, "dim", None)
    or getattr(config, "projection_dim", None)
)

head_dtype instance-attribute

head_dtype = vllm_config.model_config.head_dtype

hidden_size instance-attribute

hidden_size = config.hidden_size

supports_late_interaction class-attribute

supports_late_interaction: Literal[True] = True

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/colbert.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    # Get config before calling super().__init__
    config = vllm_config.model_config.hf_config
    self.hidden_size = config.hidden_size
    self.head_dtype = vllm_config.model_config.head_dtype

    # ColBERT dimension - check various config field names used by different
    # ColBERT implementations. If not found in config, will be inferred
    # from loaded weights in load_weights()
    self.colbert_dim: int | None = (
        getattr(config, "colbert_dim", None)
        or getattr(config, "dim", None)
        or getattr(config, "projection_dim", None)
    )

    # Initialize parent (this will call _build_pooler)
    super().__init__(vllm_config=vllm_config, prefix=prefix)
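
For illustration, the fallback above behaves like this on a stand-in config object (SimpleNamespace is only a stand-in for hf_config here; which field name a given checkpoint uses varies across ColBERT implementations):

from types import SimpleNamespace

# Stand-in for hf_config from a checkpoint that stores the projection
# width under "dim" rather than "colbert_dim" or "projection_dim".
config = SimpleNamespace(dim=128)

colbert_dim = (
    getattr(config, "colbert_dim", None)
    or getattr(config, "dim", None)
    or getattr(config, "projection_dim", None)
)
print(colbert_dim)  # 128; stays None if no field matches, deferring to load_weights()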

_build_colbert_linear

_build_colbert_linear() -> Linear

Build the ColBERT linear projection layer.

Source code in vllm/model_executor/models/colbert.py
def _build_colbert_linear(self) -> nn.Linear:
    """Build the ColBERT linear projection layer."""
    if self.colbert_dim is None:
        raise ValueError("colbert_dim must be set before building the linear layer")
    return nn.Linear(
        self.hidden_size,
        self.colbert_dim,
        bias=False,
        dtype=self.head_dtype,
    )
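
To make the shapes concrete, here is a minimal sketch of what the resulting layer does to encoder hidden states (768 and 128 are typical BERT-base/ColBERT values, not read from any config; the L2 normalization follows the original ColBERT recipe and is not part of this layer itself):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, colbert_dim = 768, 128               # typical BERT-base / ColBERT values
colbert_linear = nn.Linear(hidden_size, colbert_dim, bias=False)

hidden_states = torch.randn(7, hidden_size)       # 7 tokens out of the encoder
token_embeds = colbert_linear(hidden_states)      # [7, 128] per-token embeddings
token_embeds = F.normalize(token_embeds, dim=-1)  # unit-norm tokens for MaxSim
print(token_embeds.shape)                         # torch.Size([7, 128])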

_build_model

_build_model(
    vllm_config: VllmConfig, prefix: str = ""
) -> BertModel
Source code in vllm/model_executor/models/colbert.py
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
    return BertModel(vllm_config=vllm_config, prefix=prefix)

_build_pooler

_build_pooler(pooler_config: PoolerConfig) -> Pooler
Source code in vllm/model_executor/models/colbert.py
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
    # ColBERT linear projection: hidden_size -> colbert_dim
    # Original ColBERT uses bias=False
    # If colbert_dim is not set from config, it will be inferred during
    # load_weights and the linear layer will be created there
    if self.colbert_dim is not None:
        self.colbert_linear = self._build_colbert_linear()
    else:
        # Placeholder - will be created when weights are loaded
        self.colbert_linear = None

    # ColBERT only supports token_embed - it's fundamentally a per-token
    # embedding model.
    return pooler_for_token_embed(
        pooler_config,
        projector=self.colbert_linear,
    )

load_weights

load_weights(weights: Iterable[tuple[str, Tensor]])
Source code in vllm/model_executor/models/colbert.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    def _strip(name: str) -> str:
        for p in ("model.", "bert."):
            if name.startswith(p):
                name = name[len(p) :]
        return name

    weights_list = list(weights)
    model_side: list[tuple[str, torch.Tensor]] = []
    colbert_side: list[tuple[str, torch.Tensor]] = []

    for name, weight in weights_list:
        stripped = _strip(name)
        # Handle different checkpoint naming conventions for ColBERT linear
        if stripped in ("linear.weight", "colbert_linear.weight"):
            colbert_side.append(("colbert_linear.weight", weight))
        elif stripped.startswith("colbert_linear."):
            colbert_side.append((stripped, weight))
        elif stripped.startswith("linear."):
            # Prefix bare "linear.*" names so they match the attribute name
            colbert_side.append(("colbert_" + stripped, weight))
        else:
            model_side.append((stripped, weight))

    # Load base BERT weights using BertModel.load_weights which handles QKV fusion
    loaded: set[str] = set()
    loaded_model = self.model.load_weights(model_side)
    loaded.update({"model." + n for n in loaded_model})

    # Load ColBERT linear weights
    if colbert_side:
        for name, weight in colbert_side:
            if name == "colbert_linear.weight":
                # Infer colbert_dim from weights if not set in config
                if self.colbert_dim is None:
                    # Weight shape is [colbert_dim, hidden_size]
                    self.colbert_dim = weight.shape[0]
                    # Create the linear layer now that we know the dimension
                    self.colbert_linear = self._build_colbert_linear()
                    # Move to the same device as the model's existing parameters
                    device = next(self.model.parameters()).device
                    self.colbert_linear.to(device)
                    # Update the pooler's projector to use the new linear layer
                    self.pooler.head.projector = self.colbert_linear

                # Load weights directly into the pooler's projector
                weight = weight.to(self.pooler.head.projector.weight.device)
                self.pooler.head.projector.weight.data.copy_(weight)
                loaded.add("pooler.head.projector.weight")
                break

    return loaded
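
The dimension inference in the loop above relies on PyTorch's nn.Linear weight layout; a small self-contained check of that convention (the tensor here is synthetic, not from a real checkpoint):

import torch
import torch.nn as nn

# nn.Linear stores weight as [out_features, in_features], so a checkpoint
# tensor for colbert_linear.weight has shape [colbert_dim, hidden_size].
weight = torch.randn(128, 768)   # synthetic stand-in for colbert_linear.weight
colbert_dim, hidden_size = weight.shape
assert nn.Linear(hidden_size, colbert_dim, bias=False).weight.shape == weight.shape
print(colbert_dim)               # 128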