Skip to content

vllm.model_executor.models.olmo_hybrid

Inference-only OLMo Hybrid model compatible with HuggingFace weights.

OlmoHybridGatedDeltaNet

Bases: Module, MambaBase

Gated DeltaNet linear attention layer for OLMo Hybrid.

This implements the linear attention mechanism that replaces sliding window attention in the hybrid architecture.

Source code in vllm/model_executor/models/olmo_hybrid.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
    """
    Gated DeltaNet linear attention layer for OLMo Hybrid.

    This implements the linear attention mechanism that replaces sliding window
    attention in the hybrid architecture.

    Recurrent state is read from ``self.kv_cache[virtual_engine]`` as a pair:
    element 0 is the conv state and element 1 is the SSM (delta-rule) state.
    ``self.kv_cache`` is not assigned in this class; presumably the engine
    binds it after construction -- TODO confirm.

    The complete forward pass (projections, short conv, recurrence, output
    norm, output projection) runs inside the
    ``torch.ops.vllm.olmo_hybrid_gdn_full_forward`` custom op, which looks the
    layer up by ``prefix``; see ``forward`` for the rationale.
    """

    @property
    def mamba_type(self) -> str:
        # Tag string exposed through MambaBase; presumably used to pick the
        # GDN attention backend / state layout for this layer -- TODO confirm.
        return "gdn_attention"

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of this layer's two recurrent-state tensors.

        Delegates to ``MambaStateDtypeCalculator.gated_delta_net_state_dtype``
        with the model dtype and the cache config's mamba cache dtypes, so
        ``model_config`` and ``cache_config`` must not be None when this is
        called. Presumably ordered (conv_state, ssm_state) to match
        ``self.kv_cache[...][0]`` / ``[1]`` -- TODO confirm.
        """
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
            self.cache_config.mamba_ssm_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the shapes of this layer's two recurrent-state tensors.

        Computed by ``MambaStateShapeCalculator.gated_delta_net_state_shape``
        from the tensor-parallel size, the global key/value head counts and
        head dims, the conv kernel size, and ``num_spec`` (number of
        speculative tokens, 0 when speculative decoding is off).
        """
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            self.tp_size,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            self.conv_kernel_size,
            self.num_spec,
        )

    def __init__(
        self,
        config,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        speculative_config: SpeculativeConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, fused conv, gating params and output path.

        Args:
            config: HF model config. Reads ``hidden_size``,
                ``linear_num_value_heads``, ``linear_num_key_heads``,
                ``linear_key_head_dim``, ``linear_value_head_dim``,
                ``linear_conv_kernel_dim``, ``hidden_act``, ``rms_norm_eps``,
                and optionally ``linear_use_gate`` (must be truthy),
                ``linear_allow_neg_eigval`` (default False) and
                ``torch_dtype``.
            model_config: stored; dereferenced by ``get_state_dtype``.
            cache_config: stored; dereferenced by ``get_state_dtype``.
            quant_config: passed to the q/k/v/g, b, a and o projections
                (not to ``conv1d``).
            speculative_config: if given, ``num_speculative_tokens`` sets
                ``num_spec``; otherwise ``num_spec`` is 0.
            prefix: layer name. Used for the layer index, the weight
                prefixes, and as the key under which this layer registers
                itself in ``static_forward_context``.

        Raises:
            AssertionError: if ``config.linear_use_gate`` is falsy.
            ValueError: if ``prefix`` is already registered in
                ``static_forward_context``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        # Global (un-sharded) widths of the key and value projections.
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = extract_layer_index(prefix)
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        assert getattr(config, "linear_use_gate", True), (
            "OlmoHybridGatedDeltaNet requires linear_use_gate=True"
        )
        self.allow_neg_eigval = getattr(config, "linear_allow_neg_eigval", False)
        self.prefix = prefix

        self.config = config
        self.model_config = model_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.speculative_config = speculative_config
        self.num_spec = (
            self.speculative_config.num_speculative_tokens
            if self.speculative_config
            else 0
        )

        # Fused QKVG projection: 1 matmul instead of 4
        # Output per rank is [q | k | v | g] (each section TP-sharded), so the
        # first (2 * key_dim + value_dim) / tp columns are the conv input and
        # the remainder is the output gate (see _full_forward).
        self.in_proj_qkvg = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[self.key_dim, self.key_dim, self.value_dim, self.value_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_qkvg",
        )

        # Separate B and A projections to preserve numerical precision.
        # Fusing these into one matmul changes FP accumulation order for the
        # gating scalars, which compounds through the GDN recurrent state.
        self.b_proj = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.num_v_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.b_proj",
        )
        self.a_proj = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.num_v_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.a_proj",
        )

        # Fused conv1d: single parameter instead of 3
        # One channel per q/k/v feature; ColumnParallelLinear is used only as a
        # TP-sharded parameter holder (output dim = channels, input dim =
        # kernel taps). The weight gets a singleton middle dim below, and its
        # default loader is replaced with one that stitches the separate
        # q/k/v conv weights into this fused parameter.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            prefix=f"{prefix}.conv1d",
        )
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader": _make_fused_conv1d_weight_loader(
                    [self.key_dim, self.key_dim, self.value_dim],
                    self.tp_size,
                    self.tp_rank,
                )
            },
        )

        # Per-value-head gating parameters for this rank's head shard; both
        # are loaded with sharded_weight_loader(0). A_log is left
        # uninitialized (torch.empty) until weights are loaded.
        self.dt_bias = nn.Parameter(
            torch.ones(self.num_v_heads // self.tp_size),
        )
        self.A_log = nn.Parameter(
            torch.empty(
                divide(self.num_v_heads, self.tp_size),
            )
        )

        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        # use eps=1e-5 to match FLA's FusedRMSNormGated
        # Applied per head over head_v_dim. norm_before_gate=True presumably
        # means normalize first, then apply the gate -- see RMSNormGated.
        self.o_norm = RMSNormGated(
            self.head_v_dim,
            eps=1e-5,
            group_size=None,
            norm_before_gate=True,
            device=current_platform.current_device(),
            dtype=config.torch_dtype if hasattr(config, "torch_dtype") else None,
        )

        self.o_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # FLA triton kernels need a PyTorch-backed allocator for scratch
        # memory (required by triton >= 3.x autotuner). Set once at init.
        set_triton_allocator(current_platform.current_device())

        # Register this layer by name so the custom op can find it at run
        # time from the layer name passed in forward().
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def rearrange_mixed_qkv(self, mixed_qkv):
        """Split a packed q/k/v tensor and reshape it into per-head layout.

        Args:
            mixed_qkv: 2-D tensor ``(num_tokens, (2 * key_dim + value_dim) / tp)``
                holding this rank's q, k and v channels concatenated in that
                order, or None.

        Returns:
            ``(query, key, value)``, each contiguous with shape
            ``(1, num_tokens, heads, head_dim)``, or ``(None, None, None)``
            when ``mixed_qkv`` is None. If this rank has more value heads
            than key heads, query and key are expanded to ``num_v_heads``
            heads by repeating each key head ``num_v_heads // num_k_heads``
            times consecutively.
        """
        if mixed_qkv is None:
            return None, None, None
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim // self.tp_size,
                self.key_dim // self.tp_size,
                self.value_dim // self.tp_size,
            ],
            dim=-1,
        )

        num_k_heads = self.num_k_heads // self.tp_size
        num_v_heads = self.num_v_heads // self.tp_size

        query = rearrange(query, "l (h d) -> 1 l h d", h=num_k_heads, d=self.head_k_dim)
        key = rearrange(key, "l (h d) -> 1 l h d", h=num_k_heads, d=self.head_k_dim)
        value = rearrange(value, "l (h d) -> 1 l h d", h=num_v_heads, d=self.head_v_dim)

        # GQA expansion if needed
        # unsqueeze + expand + reshape repeats each key head expand_ratio
        # times in a row: [h0, h0, ..., h1, h1, ...].
        if num_v_heads > num_k_heads:
            expand_ratio = num_v_heads // num_k_heads
            query = query.unsqueeze(3).expand(-1, -1, -1, expand_ratio, -1)
            query = query.reshape(1, query.shape[1], num_v_heads, self.head_k_dim)
            key = key.unsqueeze(3).expand(-1, -1, -1, expand_ratio, -1)
            key = key.reshape(1, key.shape[1], num_v_heads, self.head_k_dim)

        return query.contiguous(), key.contiguous(), value.contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Compute this layer's output into ``output`` in place; returns None.

        Dispatches to the ``olmo_hybrid_gdn_full_forward`` custom op, passing
        ``self.prefix`` as the layer name; that op calls ``_full_forward``.
        """
        # NOTE: We wrap the ENTIRE linear attention forward (projections +
        # core recurrence + output norm + output projection) in a single
        # custom op, rather than just wrapping the recurrent core like
        # other GDN models (e.g. Qwen3Next) do.
        #
        # Why: torch.compile with inductor generates fused kernels for
        # matmuls and pointwise ops. These fused kernels can differ in
        # floating-point accumulation order from eager-mode cuBLAS,
        # introducing small numerical differences (~1e-7 per op). For
        # standard transformer attention this is harmless because each
        # position is computed independently. But for the GDN recurrent
        # state, these tiny input differences compound at every timestep
        # across the full sequence length, causing severe logprob
        # divergence (e.g. ~15% top-1 agreement with eager baseline).
        #
        # By making the full forward opaque to inductor, the projections
        # and output norm run with eager-mode kernels (cuBLAS, triton),
        # preserving numerical consistency. The tradeoff is reduced
        # compilation speedup (~1.5x vs ~3x), but logprob agreement
        # improves from ~15% to ~83% top-1 vs eager.
        #
        # The remaining ~17% divergence comes from inductor compiling
        # the MLP and transformer attention layers that are NOT wrapped
        # in custom ops -- their small precision differences propagate
        # as inputs to the GDN layers from outside.
        torch.ops.vllm.olmo_hybrid_gdn_full_forward(
            hidden_states,
            output,
            self.prefix,
        )

    def _full_forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Eager body of the layer, executed inside the custom op.

        Projects ``hidden_states`` (first dim = tokens), runs the conv and
        gated-delta recurrence via ``_forward_core``, applies the gated RMS
        norm, and writes the ``o_proj`` result into ``output[:num_tokens]``.
        """
        num_tokens = hidden_states.size(0)

        # ============================================================
        # Part 1: Input Projection (3 matmuls instead of 6: one fused
        # QKVG matmul plus the deliberately separate b and a matmuls)
        # ============================================================
        projected_qkvg, _ = self.in_proj_qkvg(hidden_states)
        # Per-rank layout is [q | k | v | g]; split off the conv input
        # (q, k, v) from the output gate g.
        conv_dim_sharded = (self.key_dim * 2 + self.value_dim) // self.tp_size
        mixed_qkv = projected_qkvg[..., :conv_dim_sharded]
        gate = projected_qkvg[..., conv_dim_sharded:]

        b, _ = self.b_proj(hidden_states)
        a, _ = self.a_proj(hidden_states)

        # ============================================================
        # Part 2: Core Attention
        # ============================================================
        # Zero-initialized: _forward_core writes only the first
        # num_actual_tokens rows and returns early on the profile run, so
        # any remaining rows stay zero.
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        self._forward_core(
            mixed_qkv=mixed_qkv,
            b=b,
            a=a,
            core_attn_out=core_attn_out,
        )

        # ============================================================
        # Part 3: Output Projection
        # ============================================================
        # Flatten (tokens, heads) so the gated norm runs per head over
        # head_v_dim, then restore the per-head shape.
        gate = gate.view(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim)
        core_attn_out_flat = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        gate_flat = gate.reshape(-1, gate.shape[-1])
        core_attn_out_normed = self.o_norm(core_attn_out_flat, gate_flat)
        core_attn_out = core_attn_out_normed.view(
            num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim
        )

        core_attn_out = rearrange(core_attn_out, "l h d -> l (h d)")
        output[:num_tokens], _ = self.o_proj(core_attn_out)

    def _forward_core(
        self,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
    ):
        """
        Core attention computation (called by custom op).

        Runs the causal short conv and the gated delta rule over the current
        batch, updating this layer's conv/SSM state in ``self.kv_cache`` and
        writing the per-head result into ``core_attn_out[:num_actual_tokens]``.
        Tokens are split into a speculative-decode group and a non-speculative
        group (prefill via chunked kernels, or decode via recurrent kernels)
        and merged back into token order at the end. Returns None and does
        nothing when there is no attention metadata (profile run).
        """
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        spec_query_start_loc = attn_metadata.spec_query_start_loc
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        spec_sequence_masks = attn_metadata.spec_sequence_masks
        spec_token_indx = attn_metadata.spec_token_indx
        non_spec_token_indx = attn_metadata.non_spec_token_indx
        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
        # Element 0 is the conv state (used through a transposed view),
        # element 1 is the SSM state; both are updated in place below.
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        num_actual_tokens = attn_metadata.num_actual_tokens
        num_accepted_tokens = attn_metadata.num_accepted_tokens

        # Drop padding tokens beyond the real batch.
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]

        # Remove the singleton middle dim added in __init__ ->
        # (channels, kernel taps).
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        # --- Split tokens into speculative / non-speculative groups ---
        # With no prefills and no plain decodes, every token is speculative,
        # so no gather is needed.
        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                mixed_qkv_spec = mixed_qkv
                mixed_qkv_non_spec = None
            else:
                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
        else:
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv

        # --- Causal short conv ---
        # Speculative tokens: incremental conv update using the first state
        # slot of each speculative sequence plus num_accepted_tokens.
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
                conv_state,
                conv_weights,
                None,  # no bias
                self.activation,
                conv_state_indices=spec_state_indices_tensor[:, 0][
                    : attn_metadata.num_spec_decodes
                ],
                num_accepted_tokens=num_accepted_tokens,
                query_start_loc=spec_query_start_loc,
                max_query_len=spec_state_indices_tensor.size(-1),
                validate_data=False,
            )

        # Non-speculative tokens: full varlen conv when any prefill is
        # present (kernel takes a channels-first view), otherwise a
        # single-step update per decode sequence.
        if attn_metadata.num_prefills > 0:
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
            mixed_qkv_non_spec = causal_conv1d_fn(
                mixed_qkv_non_spec_T,
                conv_weights,
                None,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        elif attn_metadata.num_decodes > 0:
            mixed_qkv_non_spec = causal_conv1d_update(
                mixed_qkv_non_spec,
                conv_state,
                conv_weights,
                None,
                self.activation,
                conv_state_indices=non_spec_state_indices_tensor[
                    : attn_metadata.num_decodes
                ],
                validate_data=True,
            )
        else:
            mixed_qkv_non_spec = None

        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
            mixed_qkv_non_spec
        )

        # --- Gating: decay g and write strength beta for all tokens ---
        g, beta = fused_olmo_hybrid_gdn_gating(
            self.A_log, a, b, self.dt_bias, self.allow_neg_eigval
        )

        # g / beta are gathered along dim 1, i.e. tokens are presumably on
        # dim 1 behind a leading batch dim of 1 -- TODO confirm.
        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                g_spec = g
                beta_spec = beta
                g_non_spec = None
                beta_non_spec = None
            else:
                g_spec = g.index_select(1, spec_token_indx)
                beta_spec = beta.index_select(1, spec_token_indx)
                g_non_spec = g.index_select(1, non_spec_token_indx)
                beta_non_spec = beta.index_select(1, non_spec_token_indx)
        else:
            g_spec = None
            beta_spec = None
            g_non_spec = g
            beta_non_spec = beta

        # --- Gated delta-rule recurrence ---
        # Speculative tokens: recurrent kernel updating ssm_state in place.
        if spec_sequence_masks is not None:
            core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                q=query_spec,
                k=key_spec,
                v=value_spec,
                g=g_spec,
                beta=beta_spec,
                initial_state=ssm_state,
                inplace_final_state=True,
                cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
                ssm_state_indices=spec_state_indices_tensor,
                num_accepted_tokens=num_accepted_tokens,
                use_qk_l2norm_in_kernel=True,
            )
        else:
            core_attn_out_spec, last_recurrent_state = None, None

        # Non-speculative prefill: gather states, zero the ones for
        # sequences without an initial state, run the chunked kernel, then
        # scatter the final states back (cast to the cache dtype).
        if attn_metadata.num_prefills > 0:
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
            initial_state[~has_initial_state, ...] = 0
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = chunk_gated_delta_rule(
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=non_spec_query_start_loc,
                use_qk_l2norm_in_kernel=True,
            )
            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
                ssm_state.dtype
            )
        # Non-speculative decode: recurrent kernel updating ssm_state in place.
        elif attn_metadata.num_decodes > 0:
            core_attn_out_non_spec, last_recurrent_state = (
                fused_recurrent_gated_delta_rule(
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    g=g_non_spec,
                    beta=beta_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=non_spec_query_start_loc[
                        : attn_metadata.num_decodes + 1
                    ],
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
                )
            )
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

        # --- Merge outputs back into original token order ---
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
            merged_out = torch.empty(
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
        elif spec_sequence_masks is not None:
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
        else:
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)

_forward_core

_forward_core(
    mixed_qkv: Tensor,
    b: Tensor,
    a: Tensor,
    core_attn_out: Tensor,
)

Core attention computation (called by custom op).

Source code in vllm/model_executor/models/olmo_hybrid.py
def _forward_core(
    self,
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
):
    """
    Core attention computation (called by custom op).

    Runs the causal short conv and the gated delta rule over the current
    batch, updating this layer's conv/SSM state in ``self.kv_cache`` and
    writing the per-head result into ``core_attn_out[:num_actual_tokens]``.
    Tokens are split into a speculative-decode group and a non-speculative
    group (prefill via chunked kernels, or decode via recurrent kernels) and
    merged back into token order at the end. Returns None and does nothing
    when there is no attention metadata (profile run).
    """
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata

    if attn_metadata is None:
        # V1 profile run
        return

    assert isinstance(attn_metadata, dict)
    attn_metadata = attn_metadata[self.prefix]
    assert isinstance(attn_metadata, GDNAttentionMetadata)
    has_initial_state = attn_metadata.has_initial_state
    spec_query_start_loc = attn_metadata.spec_query_start_loc
    non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
    spec_sequence_masks = attn_metadata.spec_sequence_masks
    spec_token_indx = attn_metadata.spec_token_indx
    non_spec_token_indx = attn_metadata.non_spec_token_indx
    spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
    non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
    # Element 0 is the conv state (used through a transposed view),
    # element 1 is the SSM state; both are updated in place below.
    self_kv_cache = self.kv_cache[forward_context.virtual_engine]
    conv_state = self_kv_cache[0].transpose(-1, -2)
    ssm_state = self_kv_cache[1]
    num_actual_tokens = attn_metadata.num_actual_tokens
    num_accepted_tokens = attn_metadata.num_accepted_tokens

    # Drop padding tokens beyond the real batch.
    mixed_qkv = mixed_qkv[:num_actual_tokens]
    b = b[:num_actual_tokens]
    a = a[:num_actual_tokens]

    # Remove the singleton middle dim of the stored conv weight ->
    # (channels, kernel taps).
    conv_weights = self.conv1d.weight.view(
        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
    )

    # --- Split tokens into speculative / non-speculative groups ---
    # With no prefills and no plain decodes, every token is speculative, so
    # no gather is needed.
    if spec_sequence_masks is not None:
        if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
            mixed_qkv_spec = mixed_qkv
            mixed_qkv_non_spec = None
        else:
            mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
            mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
    else:
        mixed_qkv_spec = None
        mixed_qkv_non_spec = mixed_qkv

    # --- Causal short conv ---
    # Speculative tokens: incremental conv update using the first state slot
    # of each speculative sequence plus num_accepted_tokens.
    if spec_sequence_masks is not None:
        mixed_qkv_spec = causal_conv1d_update(
            mixed_qkv_spec,
            conv_state,
            conv_weights,
            None,  # no bias
            self.activation,
            conv_state_indices=spec_state_indices_tensor[:, 0][
                : attn_metadata.num_spec_decodes
            ],
            num_accepted_tokens=num_accepted_tokens,
            query_start_loc=spec_query_start_loc,
            max_query_len=spec_state_indices_tensor.size(-1),
            validate_data=False,
        )

    # Non-speculative tokens: full varlen conv when any prefill is present
    # (kernel takes a channels-first view), otherwise a single-step update
    # per decode sequence.
    if attn_metadata.num_prefills > 0:
        mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
        mixed_qkv_non_spec = causal_conv1d_fn(
            mixed_qkv_non_spec_T,
            conv_weights,
            None,
            activation=self.activation,
            conv_states=conv_state,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
    elif attn_metadata.num_decodes > 0:
        mixed_qkv_non_spec = causal_conv1d_update(
            mixed_qkv_non_spec,
            conv_state,
            conv_weights,
            None,
            self.activation,
            conv_state_indices=non_spec_state_indices_tensor[
                : attn_metadata.num_decodes
            ],
            validate_data=True,
        )
    else:
        mixed_qkv_non_spec = None

    query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
    query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
        mixed_qkv_non_spec
    )

    # --- Gating: decay g and write strength beta for all tokens ---
    g, beta = fused_olmo_hybrid_gdn_gating(
        self.A_log, a, b, self.dt_bias, self.allow_neg_eigval
    )

    # g / beta are gathered along dim 1, i.e. tokens are presumably on dim 1
    # behind a leading batch dim of 1 -- TODO confirm.
    if spec_sequence_masks is not None:
        if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
            g_spec = g
            beta_spec = beta
            g_non_spec = None
            beta_non_spec = None
        else:
            g_spec = g.index_select(1, spec_token_indx)
            beta_spec = beta.index_select(1, spec_token_indx)
            g_non_spec = g.index_select(1, non_spec_token_indx)
            beta_non_spec = beta.index_select(1, non_spec_token_indx)
    else:
        g_spec = None
        beta_spec = None
        g_non_spec = g
        beta_non_spec = beta

    # --- Gated delta-rule recurrence ---
    # Speculative tokens: recurrent kernel updating ssm_state in place.
    if spec_sequence_masks is not None:
        core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
            q=query_spec,
            k=key_spec,
            v=value_spec,
            g=g_spec,
            beta=beta_spec,
            initial_state=ssm_state,
            inplace_final_state=True,
            cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
            ssm_state_indices=spec_state_indices_tensor,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=True,
        )
    else:
        core_attn_out_spec, last_recurrent_state = None, None

    # Non-speculative prefill: gather states, zero the ones for sequences
    # without an initial state, run the chunked kernel, then scatter the
    # final states back (cast to the cache dtype).
    if attn_metadata.num_prefills > 0:
        initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
        initial_state[~has_initial_state, ...] = 0
        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = chunk_gated_delta_rule(
            q=query_non_spec,
            k=key_non_spec,
            v=value_non_spec,
            g=g_non_spec,
            beta=beta_non_spec,
            initial_state=initial_state,
            output_final_state=True,
            cu_seqlens=non_spec_query_start_loc,
            use_qk_l2norm_in_kernel=True,
        )
        ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
            ssm_state.dtype
        )
    # Non-speculative decode: recurrent kernel updating ssm_state in place.
    elif attn_metadata.num_decodes > 0:
        core_attn_out_non_spec, last_recurrent_state = (
            fused_recurrent_gated_delta_rule(
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=ssm_state,
                inplace_final_state=True,
                cu_seqlens=non_spec_query_start_loc[
                    : attn_metadata.num_decodes + 1
                ],
                ssm_state_indices=non_spec_state_indices_tensor,
                use_qk_l2norm_in_kernel=True,
            )
        )
    else:
        core_attn_out_non_spec, last_recurrent_state = None, None

    # --- Merge outputs back into original token order ---
    if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
        merged_out = torch.empty(
            (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
            dtype=core_attn_out_non_spec.dtype,
            device=core_attn_out_non_spec.device,
        )
        merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
        merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
        core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
    elif spec_sequence_masks is not None:
        core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
    else:
        core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)

_make_fused_conv1d_weight_loader

_make_fused_conv1d_weight_loader(dims, tp_size, tp_rank)

Weight loader for loading separate HF conv weights into a fused conv1d.

dims: list of original (un-sharded) dims per section, e.g. [key_dim, key_dim, value_dim].

Source code in vllm/model_executor/models/olmo_hybrid.py
def _make_fused_conv1d_weight_loader(dims, tp_size, tp_rank):
    """Weight loader for loading separate HF conv weights into a fused conv1d.

    dims: list of original (un-sharded) dims per section,
          e.g. [key_dim, key_dim, value_dim]
    """
    sharded_dims = [d // tp_size for d in dims]

    def weight_loader(param, loaded_weight, loaded_shard_id=None):
        if loaded_weight.dim() == 2:
            loaded_weight = loaded_weight.unsqueeze(1)
        dim = dims[loaded_shard_id]
        shard_size = dim // tp_size
        tp_start = tp_rank * shard_size
        sharded_weight = loaded_weight[tp_start : tp_start + shard_size]
        offset = sum(sharded_dims[:loaded_shard_id])
        param.data[offset : offset + shard_size].copy_(sharded_weight)

    return weight_loader

olmo_hybrid_gdn_full_forward

olmo_hybrid_gdn_full_forward(
    hidden_states: Tensor, output: Tensor, layer_name: str
) -> None

Full linear attention forward wrapped as a custom op.

Prevents inductor from compiling the projections around the GDN core, which would introduce numerical divergence that compounds through the recurrent state.

Source code in vllm/model_executor/models/olmo_hybrid.py
def olmo_hybrid_gdn_full_forward(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run one OLMo Hybrid GDN layer's complete forward as an opaque op.

    The layer object is looked up by ``layer_name`` in the current forward
    context's ``no_compile_layers`` and its ``_full_forward`` writes the
    result into ``output`` in place. Keeping the projections inside this op
    stops inductor from recompiling them with a different floating-point
    accumulation order, which would otherwise compound through the
    recurrent state.
    """
    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    layer._full_forward(hidden_states=hidden_states, output=output)

olmo_hybrid_gdn_full_forward_fake

olmo_hybrid_gdn_full_forward_fake(
    hidden_states: Tensor, output: Tensor, layer_name: str
) -> None

Fake implementation for torch.compile.

Source code in vllm/model_executor/models/olmo_hybrid.py
def olmo_hybrid_gdn_full_forward_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op fake implementation registered for torch.compile tracing.

    The real op only mutates ``output`` in place and returns nothing, so the
    fake has nothing to compute or allocate.
    """
    return None