
Beyond Attention: SSMs, Linear Attention & Hybrid Architectures


The future of sequence modeling isn’t attention OR recurrence — it’s knowing when to attend precisely and when to flow efficiently.

1. Why We Need Alternatives to Attention

The Transformer’s self-attention mechanism scales as $O(N^2)$ in both time and memory, where $N$ is the sequence length. For a 128K-token context window, the attention matrix contains 16.4 billion entries per layer per head. This quadratic scaling creates three concrete bottlenecks that are becoming increasingly painful in practice.

The Agentic RL Rollout Bottleneck

In reinforcement learning with verifiable rewards (RLVR), the training loop is: prompt → model generates rollout → environment verifies → reward signal → policy update. The generation (rollout) phase is autoregressive: each token depends on all previous tokens. With standard attention, this means:

   Step 1:     attend to 1 token     → 1 operation
Step 100:   attend to 100 tokens  → 100 operations
Step 10,000: attend to 10K tokens → 10,000 operations

Even with KV caching, each new token requires attending to the entire cached history. For agentic tasks where the model interacts with tools, writes code, and iterates over multiple steps, rollouts easily reach 50K–100K tokens. The generation phase consumes 80-90% of total training time.

This is not a theoretical concern. Companies like Poolside have built elaborate async architectures specifically to hide this latency. But what if the model itself were simply faster at generation?

The Memory Wall

Modern reasoning tasks demand long contexts. The KV cache grows linearly: for each new token, we store a key and value vector for every layer and every head. At 128K tokens with a 7B parameter model in bf16, the KV cache alone exceeds the memory of most GPUs.

| Task Type | Context Length | KV Cache (7B, bf16) |
|---|---|---|
| Simple QA | 2K | ~256 MB |
| Code generation | 8K | ~1 GB |
| Multi-step agentic | 32K | ~4 GB |
| Repository-level coding | 128K | ~16 GB |
| Long-horizon planning | 256K+ | ~32 GB+ |
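
These numbers can be reproduced with a quick back-of-the-envelope calculation (a sketch assuming 32 layers and grouped-query attention with 8 KV heads of dimension 128 in bf16, which roughly matches the table; a full multi-head-attention 7B model would be about 4x larger):

   def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128, bytes_per_val=2):
    """KV cache = 2 (K and V) × layers × KV heads × head_dim × tokens × bytes per value."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_val

for n_tokens in (2_048, 8_192, 32_768, 131_072):
    print(f"{n_tokens:>7} tokens: {kv_cache_bytes(n_tokens) / 2**30:.2f} GiB")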

This creates a direct tension: we want models that can reason over long horizons, but attention makes long contexts prohibitively expensive.

Inference Cost at Scale

For production deployment, the economics are stark:

  • Prefill (processing the prompt): $O(N^2)$ — manageable with parallelism
  • Decode (generating each new token): $O(N)$ per token with KV cache
  • Total decode cost for $L$ output tokens: $O(N \times L)$

When $N$ is large (long context) and $L$ is large (long generation), this becomes the dominant cost. A model that could maintain $O(1)$ per-token generation — independent of context length — would fundamentally change the economics.
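
A rough count of attention reads during a long rollout makes the $O(N \times L)$ term concrete (a back-of-the-envelope sketch that ignores per-read constants):

   N, L = 100_000, 10_000        # context length, generated tokens

# With a KV cache, step t still reads all N + t cached positions.
attention_reads = sum(N + t for t in range(L))     # ≈ N·L + L²/2 ≈ 1.05 billion
fixed_state_reads = L                              # an O(1)-per-token model reads a fixed state each step

print(attention_reads, fixed_state_reads)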

The Dream: Train Parallel, Infer Recurrent

The ideal architecture combines Transformer-like parallel training with RNN-like $O(1)$ inference:

                       Training         Inference (per step)    State Size
Standard Attention: O(N²)           O(N) with KV cache      O(N·d) grows
Linear Attention:   O(N)            O(1)                    O(d²) fixed
S4:                 O(N log N)      O(1)                    O(d) fixed
Mamba:              O(N)            O(1)                    O(d) fixed
Mamba-2:            O(N)            O(1) + tensor cores     O(d) fixed

This isn’t a pipe dream. The key realization is that certain sequence models admit dual computation forms: one parallel (for training) and one recurrent (for inference), both computing the exact same function.

The Recall Problem

Before we dive in, we should acknowledge the elephant in the room: recall.

Standard attention can, in principle, attend to any token in the history with equal ease. This makes it excellent at tasks requiring precise retrieval — “What was the function name defined on line 47?” The attention matrix can place arbitrary weight on any position.

SSMs and linear attention variants compress the entire history into a fixed-size state. This compression is inherently lossy. The question is: how much recall capability do we sacrifice for efficiency, and can we get it back?

This tension — efficiency vs. recall — is the central design challenge of the field, and it’s what drives the evolution from pure SSMs to hybrid architectures that mix attention with SSM layers.

Roadmap

In this article, we’ll trace the complete evolution:

  1. S4: The first SSM to match Transformers, via the convolution-recurrence duality
  2. Mamba: Making SSMs content-aware with selective parameters
  3. Mamba-2 (SSD): Revealing that SSMs and attention are two views of the same computation
  4. Linear Attention: The bridge connecting Transformers to RNNs through the associativity trick
  5. The Landscape: GLA, RWKV, RetNet, and the explosion of linear-time variants
  6. Hybrid Architectures: Jamba, Zamba, Samba — the pragmatic solution that combines attention and SSM layers
  7. Practical Implications: When to use what, and where the field is heading

Each section includes mathematical formulations, code implementations, and concrete complexity analysis. By the end, you should have a clear mental model of the entire SSM/linear-attention design space and understand which architecture fits which use case.


2. S4 — The Structured State Space Revolution

The SSM Foundation

A State Space Model describes a system that maintains an internal state x which evolves over time based on input u and produces output y:

Continuous-time SSM:

$$\frac{dx}{dt} = Ax + Bu \qquad \text{(state equation)}$$
$$y = Cx + Du \qquad \text{(output equation)}$$

Where:

  • $\mathbf{x} \in \mathbb{R}^N$ is the hidden state ($N$ = state dimension, typically 16-64)
  • $\mathbf{u} \in \mathbb{R}$ is the input signal
  • $A \in \mathbb{R}^{N \times N}$ is the state transition matrix
  • $B \in \mathbb{R}^{N \times 1}$ maps input to state
  • $C \in \mathbb{R}^{1 \times N}$ maps state to output
  • $D \in \mathbb{R}$ is a skip connection (often set to 1)

To process discrete sequences (tokens), we need to discretize these continuous equations. S4 uses Zero-Order Hold (ZOH) with step size $\Delta$:

$$\bar{A} = e^{\Delta \cdot A} \approx I + \Delta \cdot A \quad \text{(for small } \Delta\text{)}$$
$$\bar{B} = (\bar{A} - I) A^{-1} B \approx \Delta \cdot B \quad \text{(first-order approximation)}$$

This gives us the discrete SSM:

$$x_t = \bar{A} \cdot x_{t-1} + \bar{B} \cdot u_t \qquad \text{(state update)}$$
$$y_t = C \cdot x_t + D \cdot u_t \qquad \text{(output)}$$

In Python:

# Shared imports used by the code snippets throughout this article
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def discretize_zoh(A, B, delta):
    """Zero-Order Hold discretization.
    A: (d_model, d_state) — diagonal state transition
    B: (d_model, d_state) — input-to-state mapping
    delta: (d_model,)     — step size per dimension
    """
    A_bar = torch.exp(delta.unsqueeze(-1) * A)  # exp(Δ·A)
    B_bar = delta.unsqueeze(-1) * B              # Δ·B (first-order)
    return A_bar, B_bar

An alternative discretization method is the bilinear transform (Tustin’s method), which is more accurate but slightly more expensive:

   def discretize_bilinear(A, B, delta):
    """Bilinear (Tustin) discretization — more accurate for oscillatory systems"""
    dA = delta.unsqueeze(-1) * A * 0.5
    A_bar = (1 + dA) / (1 - dA)    # (1 + Δ/2·A) / (1 - Δ/2·A)
    B_bar = delta.unsqueeze(-1) * B / (1 - dA)
    return A_bar, B_bar

ZOH is simpler and works well for monotonic decay dynamics; bilinear preserves frequency content better. S4 uses ZOH; some follow-up works use bilinear.

The Dual Forms: S4’s Key Insight

Here’s the breakthrough: the discrete SSM can be computed in two mathematically equivalent ways.

Form 1: Recurrence — $O(1)$ per step, perfect for inference

Process tokens one at a time, maintaining state:

def ssm_recurrent(u, A_bar, B_bar, C, D, state=None):
    """O(N) total, O(1) per step — perfect for autoregressive generation"""
    batch, L, d_model = u.shape
    d_state = A_bar.shape[-1]
    # fixed-size state, optionally carried over from a previous call
    x = torch.zeros(batch, d_model, d_state) if state is None else state

    outputs = []
    for t in range(L):
        x = A_bar * x + B_bar * u[:, t, :].unsqueeze(-1)
        y_t = (C * x).sum(dim=-1) + D * u[:, t, :]
        outputs.append(y_t)

    return torch.stack(outputs, dim=1)

Pros: Constant memory, $O(1)$ per token — no matter how long the context.
Cons: Sequential — can’t parallelize across time steps, so training is slow.

Form 2: Convolution — $O(N \log N)$ parallel, perfect for training

Unrolling the recurrence reveals a convolution. Consider the first few outputs:

$$y_0 = C \bar{B} u_0$$
$$y_1 = C \bar{A} \bar{B} u_0 + C \bar{B} u_1$$
$$y_2 = C \bar{A}^2 \bar{B} u_0 + C \bar{A} \bar{B} u_1 + C \bar{B} u_2$$

The pattern reveals a convolution with kernel $K$:

$$K = \bigl[C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^2\bar{B},\; \dots,\; C\bar{A}^{L-1}\bar{B}\bigr]$$

Then: $\mathbf{y} = K * \mathbf{u}$ (1D convolution), computable via FFT in $O(N \log N)$.

   def compute_ssm_kernel(A_bar, B_bar, C, L):
    """Compute convolution kernel K[t] = C · A_bar^t · B_bar

    The kernel only depends on A, B, C, Δ — NOT on the input.
    Precompute once for a given sequence length and reuse for all sequences.
    """
    powers = torch.arange(L, device=A_bar.device).float()
    A_powers = A_bar.unsqueeze(-1) ** powers    # (d_model, d_state, L)
    kernel = torch.einsum('dn,dn,dnl->dl', C, B_bar, A_powers)
    return kernel  # (d_model, L)

def ssm_conv_mode(u, K, D):
    """O(N log N) via FFT — perfect for training"""
    u_t = u.transpose(1, 2)         # (batch, d_model, L)
    n_fft = 2 * u.shape[1]          # pad to avoid circular convolution

    u_fft = torch.fft.rfft(u_t, n=n_fft, dim=-1)
    K_fft = torch.fft.rfft(K, n=n_fft, dim=-1)
    y_fft = u_fft * K_fft.unsqueeze(0)
    y = torch.fft.irfft(y_fft, n=n_fft, dim=-1)[..., :u.shape[1]]

    return y.transpose(1, 2) + D * u

Pros: Fully parallel, exploits highly optimized FFT routines.
Cons: $O(N)$ memory for the kernel, $O(N \log N)$ time.

Why Both Forms Compute the Same Thing

This equivalence holds because the SSM is a Linear Time-Invariant (LTI) system. For any LTI system:

  • The impulse response (kernel K) completely characterizes the input-output relationship
  • The output is the convolution of input with the impulse response
  • But you can also compute this incrementally via the recurrence

The kernel only depends on A, B, C, and Δ — not on the input. So you can precompute it once for a given sequence length and reuse it for all sequences. During training, you use the convolution form (parallel). During inference, you switch to recurrence form ($O(1)$ per step). The model computes the exact same function either way.

This dual computation pattern — parallel for training, recurrent for inference — is the central principle that all subsequent SSM and linear attention architectures exploit.
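
The equivalence is easy to sanity-check numerically with the reference functions above (a small sketch with arbitrary parameter values):

   torch.manual_seed(0)
batch, L, d_model, d_state = 2, 64, 8, 16

A = -(torch.rand(d_model, d_state) + 0.5)        # negative values → stable decay
B = torch.randn(d_model, d_state) * 0.1
C = torch.randn(d_model, d_state) * 0.1
D = torch.ones(d_model)
delta = torch.full((d_model,), 0.01)
u = torch.randn(batch, L, d_model)

A_bar, B_bar = discretize_zoh(A, B, delta)
y_recurrent = ssm_recurrent(u, A_bar, B_bar, C, D)
y_convolution = ssm_conv_mode(u, compute_ssm_kernel(A_bar, B_bar, C, L), D)
print(torch.allclose(y_recurrent, y_convolution, atol=1e-4))   # True, up to FFT round-off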

HiPPO: Why Initialization Matters

S4’s other key contribution is the HiPPO (High-order Polynomial Projection Operator) framework for initializing the A matrix.

Random initialization of A leads to either:

  • Rapid decay: information is forgotten too quickly (eigenvalues too negative)
  • Instability: states explode (eigenvalues with positive real parts)

HiPPO provides a principled initialization based on optimally compressing the history using orthogonal polynomials (specifically, Legendre polynomials). The idea: at each timestep, the state should be the best polynomial approximation of the entire input history seen so far.

The diagonal approximation:

   def make_hippo_diagonal(N):
    """A_n = -(n + 1/2) for n = 0, 1, ..., N-1
    Each dimension decays at a different rate, creating multi-scale memory:
      - Dimension 0:  decay 0.5  (slow, captures long-range patterns)
      - Dimension 15: decay 15.5 (medium, captures mid-range context)
      - Dimension 63: decay 63.5 (fast, captures fine local detail)
    """
    return -(torch.arange(N).float() + 0.5)

This multi-scale memory is what allows S4 to handle sequences of 16,000+ tokens — far beyond what LSTMs and GRUs could achieve. The slow dimensions remember events from thousands of steps ago; the fast dimensions capture recent local patterns. Together, they form a multi-resolution representation of the entire history.
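
To make this concrete, here is a tiny sketch of how much of a unit impulse each state dimension retains after k steps, using the initialization above and a step size of Δ = 0.01:

   A = make_hippo_diagonal(64)            # -0.5, -1.5, ..., -63.5
A_bar = torch.exp(0.01 * A)            # per-step retention factor for each dimension

for k in (10, 100, 1_000):
    retained = A_bar ** k              # fraction of an impulse surviving k steps
    print(f"k={k:>5}: dim 0 keeps {retained[0]:.3f}, dim 63 keeps {retained[63]:.1e}")

Dimension 0 still holds a large fraction of the impulse hundreds of steps later, while dimension 63 has effectively forgotten it within a handful of steps.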

S4: Putting It All Together

The complete S4 layer combines discretization, dual-mode computation, and HiPPO initialization:

   class S4Layer(nn.Module):
    def __init__(self, d_model, d_state=64):
        super().__init__()
        # SSM parameters (diagonal form for efficiency)
        self.A_log = nn.Parameter(
            torch.log(make_hippo_diagonal(d_state).abs() + 1e-4)
            .unsqueeze(0).expand(d_model, -1).clone()
        )
        self.B = nn.Parameter(torch.randn(d_model, d_state) * 0.1)
        self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.1)
        self.D = nn.Parameter(torch.ones(d_model))
        self.log_delta = nn.Parameter(
            torch.log(torch.ones(d_model) * 0.01)  # Small initial step size
        )
        self._kernel_cache = None
        self._cached_length = 0

    def forward(self, u, state=None):
        """Training: convolution form. Inference: recurrence form."""
        A = -torch.exp(self.A_log)
        delta = torch.exp(self.log_delta)
        A_bar, B_bar = discretize_zoh(A, self.B, delta)

        if state is not None:
            # Inference mode: O(1) per step recurrence
            return ssm_recurrent(u, A_bar, B_bar, self.C, self.D, state)
        else:
            # Training mode: O(N log N) convolution
            L = u.shape[1]
            K = compute_ssm_kernel(A_bar, B_bar, self.C, L)
            return ssm_conv_mode(u, K, self.D)

    def step(self, u_t, state):
        """O(1) single step for autoregressive generation.
        u_t: (batch, d_model) — single token
        state: (batch, d_model, d_state) — running SSM state
        """
        A = -torch.exp(self.A_log)
        delta = torch.exp(self.log_delta)
        A_bar, B_bar = discretize_zoh(A, self.B, delta)

        state = A_bar * state + B_bar * u_t.unsqueeze(-1)
        y_t = (self.C * state).sum(dim=-1) + self.D * u_t
        return y_t, state

Note the step() method: during autoregressive generation, we call this once per token. The state is a fixed-size tensor of shape (batch, d_model, d_state) — typically around 64 floats per dimension. No matter how many tokens we’ve generated, the state never grows.

S4’s Limitation: Time-Invariance

S4 has one fundamental weakness: A, B, C, and Δ are fixed parameters — they don’t depend on the input. This means:

  • Every token is processed with the same transition dynamics
  • The model can’t selectively remember or forget based on content
  • It’s like reading a book where you highlight every word with the same color

Consider: “The capital of France is ___. France is in Europe.”

An S4 model processes both “France” tokens identically. It can’t decide that the first “France” (in a question context) should be strongly encoded while the second (in a statement context) is less important. The kernel K is precomputed and applies uniformly to all inputs.

This is the fundamental tradeoff of LTI systems: the convolution trick that enables $O(N \log N)$ training requires time-invariance. Making parameters input-dependent breaks the convolution — which means we need a new parallel computation strategy.

S4 Complexity Summary

| Aspect | S4 | Standard Attention |
|---|---|---|
| Training | $O(N \log N)$ via FFT | $O(N^2)$ |
| Inference per step | $O(1)$ | $O(N)$ with KV cache |
| State size | $O(d_{\text{state}})$ ≈ 64 floats | $O(N \times d)$, grows with context |
| Long sequences | Excellent (via HiPPO) | Quadratic cost |
| Content awareness | None (time-invariant) | Full (softmax attention) |
| Kernel precompute | Yes — reuse across sequences | N/A |

3. Mamba — Making SSMs Content-Aware

The Selection Mechanism

Mamba’s core innovation is making SSM parameters input-dependent:

S4 (Time-Invariant) — same $\bar{A}, \bar{B}$ for all timesteps:

$$x_t = \bar{A} \cdot x_{t-1} + \bar{B} \cdot u_t$$

Mamba (Selective) — parameters depend on current input:

$$B_t = \text{Linear}(u_t), \quad C_t = \text{Linear}(u_t), \quad \Delta_t = \text{softplus}(\text{Linear}(u_t))$$
$$\bar{A}_t = e^{\Delta_t \cdot A}, \quad x_t = \bar{A}_t \cdot x_{t-1} + \Delta_t \cdot B_t \cdot u_t$$

The parameters serve distinct roles:

  • B_t (input-dependent): Controls what information to write to the state — “what should we remember about this token?”
  • C_t (input-dependent): Controls what information to read from the state — “what past context is relevant right now?”
  • Δ_t (input-dependent): Controls the importance of the current token — high Δ means “this token matters, encode it strongly”; low Δ means “pass through, don’t disrupt the state”
  • A (fixed, NOT input-dependent): Controls the base decay rate — how quickly old information fades

Note that A is deliberately NOT input-dependent. This is a design choice: A controls the fundamental decay dynamics of the state, which can be learned but don’t need to vary per-token. Making A input-dependent would increase the parameter count and computation without significant quality gains. B and C control the read/write operations, which do need input-dependent flexibility. Δ controls how much the current timestep “matters,” effectively modulating how strongly A’s decay is applied.

Why Selection Matters: The Copy Task Intuition

Consider: “The capital of France is ___. France is in Europe.”

S4 (time-invariant): Every token gets the same write strength to the state. Both “France” tokens are processed identically. “The” and “is” receive the same encoding effort as “capital” and “France.” The model can’t decide which tokens are important — it’s like a camera with a fixed exposure.

Mamba (selective): The model can:

  • Strongly write the “France” → “capital” association (high Δ, informative B) — this token matters
  • Weakly write filler tokens like “is” and “in” (low Δ) — these are structural, not semantic
  • Selectively forget irrelevant context (via A decay modulated by Δ) — clear space for new important information

This is like a camera with adaptive exposure: bright for important subjects, dark for the background.

The copy task makes this even clearer: given “ABCDE…copy…ABCDE”, the model must remember the initial sequence and reproduce it. S4 treats the “copy” token the same as “A” — it can’t signal “now start reading from state.” Mamba can make Δ very small during the “copy” token (don’t write, just pass through) and then use C to read the relevant stored information.

The Challenge: No More Convolution

With input-dependent parameters, S4’s convolution trick breaks completely. The kernel $K = [C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^2\bar{B},\; \dots]$ was precomputable because $C$, $\bar{A}$, $\bar{B}$ were the same for all inputs. In Mamba, $B_t$, $C_t$, and $\Delta_t$ differ for every token — there’s no single kernel to precompute.

Naive approaches:

  • Process sequentially: $O(N)$ but no parallelism → slow training on GPUs
  • Materialize the full $N \times N$ matrix: $O(N^2)$ → defeats the purpose of using an SSM

The Parallel Scan Solution

Mamba’s answer is the parallel scan (also called prefix sum) algorithm. The key insight is that the SSM recurrence is associative:

Given two consecutive operations $(A_1, b_1)$ and $(A_2, b_2)$, where each computes $x_t = A_t \cdot x_{t-1} + b_t$, their composition is:

$$(A_1, b_1) \oplus (A_2, b_2) = (A_1 \cdot A_2,\; A_2 \cdot b_1 + b_2)$$

This works because: if x₁ = A₁·x₀ + b₁ and x₂ = A₂·x₁ + b₂, then x₂ = A₂·A₁·x₀ + A₂·b₁ + b₂ (and since the Aₜ here are diagonal, A₂·A₁ = A₁·A₂). The combined operation is still a linear function of x₀, with the same (multiply, add) structure.

Because this operation is associative (but not commutative), we can use parallel prefix computation:

  • Total work: $O(N)$ — same as sequential
  • Parallel depth: $O(\log N)$ — with $N$ processors, we finish in $\log N$ steps
  • In practice: Implemented as a custom CUDA kernel with careful memory management
   # Conceptual parallel scan (actual Mamba uses fused CUDA kernels)
def parallel_scan_associative(A_bars, B_bar_u):
    """
    A_bars:  (batch, L, d_inner, d_state) — per-step decay
    B_bar_u: (batch, L, d_inner, d_state) — per-step input contribution

    Returns: all hidden states x_1, x_2, ..., x_L

    The algorithm works in O(log L) parallel rounds:
    Round 1: combine pairs (1,2), (3,4), (5,6), ...
    Round 2: combine results at distance 2, 4, 8, ...
    After log(L) rounds, all prefix sums are computed.
    """
    # This is O(N) total work, O(log N) depth
    # Real implementation uses mamba_ssm.ops.selective_scan_cuda
    pass
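
For reference, here is a working (if unoptimized) log-depth scan that implements the associative combine above with plain tensor ops; the fused Mamba kernel computes the same prefix states far more efficiently:

   def parallel_scan_reference(A_bars, B_bar_u):
    """Hillis-Steele-style inclusive scan of x_t = a_t · x_{t-1} + b_t, with x_0 = 0.

    A_bars, B_bar_u: (batch, L, d_inner, d_state) — elementwise (diagonal) dynamics.
    O(log L) sequential rounds; this simple version does O(L log L) total work,
    whereas the fused CUDA kernel achieves O(L).
    """
    A, B = A_bars.clone(), B_bar_u.clone()
    L = A.shape[1]
    offset = 1
    while offset < L:
        # Combine each prefix ending at t-offset with the segment ending at t:
        # (A1, b1) ⊕ (A2, b2) = (A1·A2, A2·b1 + b2)
        A_prev, B_prev = A[:, :-offset], B[:, :-offset]
        B = torch.cat([B[:, :offset], A[:, offset:] * B_prev + B[:, offset:]], dim=1)
        A = torch.cat([A[:, :offset], A_prev * A[:, offset:]], dim=1)
        offset *= 2
    return B   # all hidden states x_1, ..., x_L stacked along dim=1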

The Selective Scan: Step by Step

Here’s the complete selective scan in explicit sequential form (the parallel scan computes the same result):

   def selective_scan(u, delta, A, B, C, D):
    """Core Mamba operation — sequential reference implementation.

    u:     (batch, L, d_inner) — expanded input
    delta: (batch, L, d_inner) — input-dependent step size
    A:     (d_inner, d_state)  — FIXED state transition (negative values)
    B:     (batch, L, d_state) — input-dependent write gate
    C:     (batch, L, d_state) — input-dependent read gate
    D:     (d_inner,)          — skip connection

    Returns: (batch, L, d_inner) — output sequence
    """
    batch, L, d_inner = u.shape
    d_state = A.shape[1]
    x = torch.zeros(batch, d_inner, d_state)  # Hidden state

    outputs = []
    for t in range(L):
        # Step 1: Discretize with input-dependent delta
        #   Large delta → strong update (token matters)
        #   Small delta → weak update (token is filler)
        delta_A = delta[:, t, :].unsqueeze(-1) * A     # (batch, d_inner, d_state)
        A_bar_t = torch.exp(delta_A)                     # Decay factor

        # Step 2: Input contribution = delta * B * u
        #   This is the "write" operation: encoding the current token
        delta_B_u = (delta[:, t, :].unsqueeze(-1)         # importance
                     * B[:, t, :].unsqueeze(1)             # write gate
                     * u[:, t, :].unsqueeze(-1))           # content

        # Step 3: State update — decay old state, add new input
        x = A_bar_t * x + delta_B_u

        # Step 4: Output — read from state + skip connection
        y_t = (x * C[:, t, :].unsqueeze(1)).sum(dim=-1) + D * u[:, t, :]
        outputs.append(y_t)

    return torch.stack(outputs, dim=1)

The Full Mamba Block

Mamba wraps the selective SSM in a gated architecture inspired by Gated Linear Units:

   class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand

        # 1. Input projection → x (main path) and z (gating branch)
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # 2. Depthwise convolution for local context
        self.conv1d = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv,
                                padding=d_conv-1, groups=d_inner)

        # 3. Selective parameter projections
        self.x_proj = nn.Linear(d_inner, d_state * 2, bias=False)  # B, C
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)       # delta

        # 4. Fixed A (not input-dependent) — initialized as log-spaced decay
        self.A_log = nn.Parameter(torch.log(
            torch.arange(1, d_state + 1).float()
        ).unsqueeze(0).expand(d_inner, -1).clone())
        self.D = nn.Parameter(torch.ones(d_inner))

        # 5. Output projection
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        # Split into main path and gating path
        xz = self.in_proj(x)                            # (batch, L, 2*d_inner)
        x_inner, z = xz.chunk(2, dim=-1)

        # Local context via depthwise conv1d
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :x.shape[1]]
        x_conv = F.silu(x_conv.transpose(1, 2))

        # Generate selective parameters from conv-processed input
        BC = self.x_proj(x_conv)
        B, C = BC.chunk(2, dim=-1)
        delta = F.softplus(self.dt_proj(x_conv))
        A = -torch.exp(self.A_log)

        # Selective scan — the core SSM computation
        y = selective_scan(x_conv, delta, A, B, C, self.D)

        # Gating: y_final = y * SiLU(z)
        y = y * F.silu(z)
        return self.out_proj(y)

Why Conv1d Before the SSM?

The depthwise convolution (kernel size typically 4) serves multiple purposes:

  • Local context window: Provides n-gram-like local information that the SSM can use when computing its selective parameters. Without it, B_t and C_t would depend only on the current token embedding, missing crucial local context.
  • Compensates for smoothing: SSMs tend to smooth over local patterns due to their recurrent nature. The conv captures sharp local features.
  • Computationally cheap: groups=d_inner means each channel is convolved independently — no cross-channel interaction, so the cost is $O(d_{\text{inner}} \times d_{\text{conv}} \times L)$.
  • Analogous to positional cues: Provides local structure similar to how positional encodings work in Transformers.

The Gating Mechanism

   y_final = y_ssm * SiLU(z)

The z branch is a separate projection of the input that acts as a learned gate, similar to the Gated Linear Unit (GLU) family (SwiGLU, GEGLU, etc.). This:

  • Helps gradient flow through the network (z provides a direct path)
  • Provides feature selection independent of the SSM dynamics — some features may be passed through directly
  • Is standard in modern architectures and empirically beneficial

O(1) Inference: The Step Function

For autoregressive generation, we only need to process one token at a time. This is where Mamba’s fixed-state advantage shines:

   def step(self, x_t, conv_state, ssm_state):
    """O(1) per token — the whole point of using SSMs.

    x_t:        (batch, d_model)             — single token embedding
    conv_state: (batch, d_inner, d_conv-1)   — sliding window for conv
    ssm_state:  (batch, d_inner, d_state)    — SSM hidden state

    Returns: y_t, new_conv_state, new_ssm_state
    """
    # Input projection
    xz = self.in_proj(x_t)                   # (batch, 2*d_inner)
    x_inner, z = xz.chunk(2, dim=-1)

    # Conv: shift the sliding window and apply convolution
    conv_full = torch.cat([conv_state, x_inner.unsqueeze(-1)], dim=-1)
    x_conv = (conv_full * self.conv1d.weight.squeeze(1)).sum(dim=-1)
    x_conv = x_conv + self.conv1d.bias
    x_conv = F.silu(x_conv)
    conv_state = conv_full[:, :, 1:]         # Slide window forward

    # Generate selective parameters for this single token
    BC = self.x_proj(x_conv)
    B, C = BC.chunk(2, dim=-1)
    delta = F.softplus(self.dt_proj(x_conv))
    A = -torch.exp(self.A_log)

    # Single SSM step: O(d_inner × d_state) ≈ O(d²)
    A_bar = torch.exp(delta.unsqueeze(-1) * A)
    delta_B_u = delta.unsqueeze(-1) * B.unsqueeze(1) * x_conv.unsqueeze(-1)
    ssm_state = A_bar * ssm_state + delta_B_u
    y = (ssm_state * C.unsqueeze(1)).sum(dim=-1) + self.D * x_conv

    # Gating and output
    y = y * F.silu(z)
    return self.out_proj(y), conv_state, ssm_state

The total state per layer consists of:

  • conv_state: (d_inner, d_conv-1) — sliding window for the convolution, typically (128, 3) = 384 floats
  • ssm_state: (d_inner, d_state) — the SSM hidden state, typically (128, 16) = 2,048 floats

Compare to a Transformer’s KV cache per layer at 100K context:

  • KV cache: (2 × n_heads × head_dim × 100K) ≈ 100M+ values

That’s roughly a 50,000x reduction in per-layer state. This is why Mamba can generate tokens at long contexts without the memory bottleneck that plagues Transformers.

What Mamba Achieves

| Aspect | Transformer | Mamba |
|---|---|---|
| Training | $O(N^2 d)$ | $O(N d^2)$ via parallel scan |
| Inference per step | $O(Nd)$ | $O(d^2)$ (independent of $N$) |
| KV cache / state | $O(Nd)$ — grows linearly | $O(d)$ (constant) |
| Long context | Expensive | Near-free |
| Content awareness | Full (softmax) | Selective (input-dependent B, C, Δ) |
| Hardware | Standard matmul → tensor cores | Custom CUDA kernels |

The Falcon Mamba 7B model, trained on 5.8 trillion tokens, surpassed Mistral 7B, Llama 3.1 8B, and Falcon2 11B on the Open LLM Leaderboard — proving that pure SSM architectures can compete at scale.

The Remaining Question

Mamba’s parallel scan relies on custom CUDA kernels hand-tuned for GPU memory hierarchies. This works, but:

  • It’s hard to maintain and port across hardware
  • It can’t leverage tensor cores — the fastest hardware path on modern GPUs (optimized for matrix multiplication, not scan)
  • Custom kernels are a bottleneck for adoption

What if we could express the SSM computation as standard matrix multiplication, letting us use the most optimized hardware path available?

That’s exactly what Mamba-2 achieves.


4. Mamba-2: The State Space Duality Revealed

SSM = Semi-Separable Matrix Multiplication

Mamba-2’s breakthrough is revealing a deep mathematical connection: SSM computation is equivalent to multiplication by a semi-separable matrix. This connects SSMs to attention at a fundamental level and enables using tensor cores.

Consider the SSM output unrolled:

$$y_0 = C_0 B_0 u_0$$
$$y_1 = C_1 A B_0 u_0 + C_1 B_1 u_1$$
$$y_2 = C_2 A^2 B_0 u_0 + C_2 A B_1 u_1 + C_2 B_2 u_2$$

This can be written as a matrix multiplication $\mathbf{y} = M \cdot \mathbf{u}$, where:

$$M[i,j] = \begin{cases} C_i \cdot A^{i-j} \cdot B_j & \text{for } i \geq j \text{ (causal)} \\ 0 & \text{for } i < j \end{cases}$$

This matrix $M$ is semi-separable: each element is a product of three factors — a row factor ($C_i$), a decay factor ($A^{i-j}$), and a column factor ($B_j$). It’s lower-triangular (causal) with exponential decay away from the diagonal.

We can build it explicitly:

   def build_semiseparable_matrix(A, B, C, L):
    """Build the semi-separable matrix M[i,j] = C_i · diag(A^{i-j}) · B_j

    This makes the SSM-attention duality concrete:
    - Standard attention: M = softmax(Q @ K^T / √d)
    - SSD (SSM):         M[i,j] = C_i · A^{i-j} · B_j

    A: (d_state,) — decay values (scalar per state dimension)
    B: (L, d_state) — per-position "key" (write gate)
    C: (L, d_state) — per-position "query" (read gate)
    """
    positions = torch.arange(L)
    diff = positions.unsqueeze(1) - positions.unsqueeze(0)  # diff[i, j] = i - j, shape (L, L)

    # Decay matrix: A^{i-j} for each state dimension
    D = A.view(-1, 1, 1) ** diff.clamp(min=0).unsqueeze(0)  # (d_state, L, L)

    # Apply causal mask
    causal_mask = torch.tril(torch.ones(L, L))
    D = D * causal_mask.unsqueeze(0)

    # M[i,j] = Σ_k C[i,k] · D[k,i,j] · B[j,k]
    M = torch.einsum('id,dij,jd->ij', C, D, B)
    return M  # (L, L) — the full semi-separable matrix

The SSM-Attention Connection

Now compare with standard attention:

| | Standard Attention | SSD (Mamba-2) |
|---|---|---|
| Matrix | $\text{softmax}(QK^\top / \sqrt{d})$ | $M[i,j] = C_i \cdot A^{i-j} \cdot B_j$ |
| Query/Key | $Q, K$ | $C$ (read gate), $B$ (write gate) |
| Value | $V$ | $u$ (input) |
| Pattern | Softmax → sharp, content-dependent | Exponential decay → smooth, distance-dependent |
| Normalization | Softmax (sum-to-1) | None (implicit via decay) |
| Complexity | $O(N^2)$ to materialize | $O(N)$ via recurrence |

Both are bilinear in query and key: the output at position $i$ depends on the interaction between what we want to read ($C_i$ / $Q_i$) and what was written ($B_j$ / $K_j$). The critical difference:

  • Attention uses softmax normalization → creates sharp attention patterns that can focus on specific tokens regardless of distance
  • SSD uses exponential decay → creates smooth patterns where recent tokens are weighted more heavily

This is the Structured State Space Duality (SSD): SSMs and attention are two ends of a spectrum, connected by the semi-separable matrix structure. One uses content-based routing (softmax), the other uses position-based decay (exponential). Both compute $\mathbf{y} = M \cdot \mathbf{u}$ where $M$ is structured differently.
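
To make the duality concrete, we can build both mixing matrices for a toy sequence and apply them to the same values (a small sketch reusing build_semiseparable_matrix from above; the softmax side uses an ordinary causal mask):

   L_seq, d_state = 8, 4
u = torch.randn(L_seq, 1)                    # one "value" channel
B = torch.randn(L_seq, d_state)              # per-position write gates / keys
C = torch.randn(L_seq, d_state)              # per-position read gates / queries
A = torch.full((d_state,), 0.9)              # fixed decay per state dimension

# SSD mixing matrix: lower-triangular, weights decay with distance i - j
M_ssd = build_semiseparable_matrix(A, B, C, L_seq)

# Softmax mixing matrix: lower-triangular, weights set by content alone
scores = (C @ B.T) / d_state ** 0.5
scores = scores.masked_fill(torch.triu(torch.ones(L_seq, L_seq), diagonal=1).bool(), float("-inf"))
M_attn = torch.softmax(scores, dim=-1)

# Either way, sequence mixing is just "structured matrix times input"
y_ssd, y_attn = M_ssd @ u, M_attn @ u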

Chunked Computation: Tensor Cores Meet RNNs

The matrix $M$ has size $L \times L$ — materializing it fully would be $O(N^2)$, defeating the purpose. Mamba-2’s key optimization is chunked processing:

  1. Divide the sequence into chunks of size $P$ (e.g., $P = 64$)
  2. Within each chunk: Use the matrix form — a small $P \times P$ matmul → tensor cores!
  3. Between chunks: Pass state via recurrence — like an RNN
   def ssm_chunked(x, A, B, C, chunk_size=64):
    """Chunked SSD computation.

    Within chunk: matrix form (tensor core acceleration)
    Between chunks: state passing (O(1) per chunk boundary)
    Total: O(N/P · P²) = O(N·P) — for fixed P, this is O(N)

    The small P×P matrices fit in registers/shared memory,
    and matrix multiplication is the most optimized operation
    on H100/H200 GPUs via tensor cores.
    """
    batch, L, d_inner = x.shape
    d_state = A.shape[-1]
    state = torch.zeros(batch, d_inner, d_state)

    all_outputs = []
    for start in range(0, L, chunk_size):
        end = min(start + chunk_size, L)
        P = end - start

        x_chunk = x[:, start:end, :]        # (batch, P, d_inner)
        B_chunk = B[:, start:end, :]         # (batch, P, d_state)
        C_chunk = C[:, start:end, :]         # (batch, P, d_state)

        # In the fused kernel: BUILD the P×P semi-separable matrix for this chunk
        # (M_chunk[i,j] = C_chunk[i] @ diag(A^{i-j}) @ B_chunk[j], a small matmul
        # that tensor cores handle efficiently), APPLY it to x_chunk, ADD the
        # contribution of the carried-in state, then UPDATE the state for the
        # next chunk. The loop below is an equivalent sequential reference.

        chunk_outputs = []
        for t in range(P):
            state = A * state + B_chunk[:, t, :].unsqueeze(1) * x_chunk[:, t, :].unsqueeze(-1)
            y_t = (C_chunk[:, t, :].unsqueeze(1) * state).sum(dim=-1)
            chunk_outputs.append(y_t)

        all_outputs.append(torch.stack(chunk_outputs, dim=1))

    return torch.cat(all_outputs, dim=1), state

Why This Is a Breakthrough

Mamba-1 (parallel scan):

  • Custom CUDA kernels hand-tuned for each GPU architecture
  • Can’t leverage tensor cores (the fastest operation on modern GPUs)
  • Hard to maintain, debug, and port
  • Each hardware generation needs new kernel tuning

Mamba-2 (chunked SSD):

  • Standard matrix multiplication within chunks → tensor cores
  • Uses existing highly-tuned BLAS libraries (cuBLAS, etc.)
  • More portable across hardware (tensor cores are universal)
  • Easier to understand, maintain, and optimize
  • Enables Triton implementations (more accessible than raw CUDA)

The practical speedup: 2-8x higher throughput over Mamba-1 on the same hardware.

The Multi-Head SSD Layer

Like attention, Mamba-2 uses multiple heads, each with its own SSD computation:

   class SSD(nn.Module):
    """Multi-head Structured State Space Duality layer.

    Analogous to multi-head attention:
    - Input → project to B (key), C (query), x (value)
    - Each head: independent SSD computation
    - Concatenate heads → output projection
    """
    def __init__(self, d_model, n_heads=8, d_state=64, chunk_size=64):
        super().__init__()
        self.head_dim = d_model // n_heads
        self.n_heads = n_heads
        self.d_state = d_state
        self.chunk_size = chunk_size

        # Projections (analogous to Q, K, V in attention)
        self.in_proj = nn.Linear(d_model, d_model * 3, bias=False)

        # Decay parameter per head — learned
        self.A_log = nn.Parameter(torch.randn(n_heads, self.head_dim))

        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        batch, L, _ = x.shape

        # Project: x → (B-like, C-like, input)
        x_inner, B_proj, C_proj = self.in_proj(x).chunk(3, dim=-1)

        # Reshape for multi-head: (batch, L, n_heads, head_dim)
        x_inner = x_inner.view(batch, L, self.n_heads, self.head_dim)
        B_proj = B_proj.view(batch, L, self.n_heads, self.head_dim)
        C_proj = C_proj.view(batch, L, self.n_heads, self.head_dim)

        # Decay in (0, 1) for stability
        A = torch.sigmoid(self.A_log)  # (n_heads, head_dim)

        # Chunked SSD computation per head (delegates to a chunked routine like
        # ssm_chunked above; the tensor-core-friendly kernel is omitted here)
        out = self._chunked_ssd(x_inner, B_proj, C_proj, A)

        # Reshape back and project
        out = out.reshape(batch, L, -1)
        return self.out_proj(out)

State Size: The Decisive Advantage

For inference, the state comparison is dramatic:

| Model | State Size per Layer | At 100K context (7B model) |
|---|---|---|
| Transformer (KV cache) | $2 \times n_{\text{heads}} \times d_{\text{head}} \times L$ | ~100M values (~200 MB) |
| Mamba / Mamba-2 | $d_{\text{inner}} \times d_{\text{state}}$ | ~16K values (~32 KB) |
| Ratio | | ~6,000x reduction |

For a 32-layer 7B model at 100K context:

  • Transformer KV cache: 32 × 200MB = ~6.4 GB
  • Mamba state: 32 × 32KB = ~1 MB

This translates directly to:

  • More sequences served per GPU — batch size can be much larger
  • Longer contexts without OOM — 1M+ tokens is feasible
  • Lower latency — less memory to read/write per step

The Unification Insight

The SSD framework reveals that four seemingly different architectures are all computing variations of structured matrix multiplication:

   RNN:              x_t = f(W·x_{t-1} + U·u_t)       — nonlinear recurrence
Linear Attention: S_t = S_{t-1} + K_t^T @ V_t       — additive accumulation (rank-1)
S4:               x_t = Ā·x_{t-1} + B̄·u_t          — linear recurrence (fixed)
Mamba:            x_t = Ā_t·x_{t-1} + B̄_t·u_t      — linear recurrence (selective)

All can be viewed as either:

  • Recurrence: sequential, $O(1)$ per step
  • Matrix multiplication: parallel, uses hardware acceleration

The choice between them is about the structure of the mixing matrix:

  • Full attention: Dense, content-dependent matrix (softmax over $QK^\top$)
  • Linear attention: Rank-1 outer product updates ($K^\top \otimes V$)
  • SSM: Semi-separable matrix with exponential decay ($C \cdot A^{i-j} \cdot B$)

This unification suggests that future architectures may freely mix these components, choosing the right structure for each layer. As we’ll see, that’s exactly what hybrid architectures do.


5. Linear Attention — The Associativity Bridge

The Core Trick

Standard attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V \qquad O(N^2 d)$$

The softmax is the bottleneck. It operates on the $N \times N$ attention matrix, coupling all positions simultaneously and making the computation inherently $O(N^2)$.

Linear attention replaces softmax with a separable feature map $\phi$:

$$\text{LinearAttn}(Q, K, V) = \frac{\phi(Q) \cdot \bigl(\phi(K)^\top V\bigr)}{\phi(Q) \cdot \phi(K)^\top \mathbf{1}}$$

The key insight is matrix associativity: we can choose the multiplication order.

Standard order — $O(N^2)$: compute the $N \times N$ matrix first:

$$\text{output} = \underbrace{\bigl(\phi(Q) \cdot \phi(K)^\top\bigr)}_{N \times N \text{ — quadratic!}} \cdot V$$

Rearranged order — $O(Nd^2)$: compute the $d \times d$ matrix first:

$$\text{output} = \phi(Q) \cdot \underbrace{\bigl(\phi(K)^\top \cdot V\bigr)}_{d \times d \text{ — linear in } N!}$$

This reordering is only valid because we removed softmax. The softmax normalization depends on all keys simultaneously (the denominator is a sum over all positions), breaking separability. By replacing softmax with a pointwise feature map, each position’s contribution becomes independent, enabling the rearrangement.

Two Equivalent Modes

Parallel Mode (Training): $O(Nd^2)$

Non-causal (bidirectional) version — straightforward:

   def linear_attention_parallel(Q, K, V, feature_map, eps=1e-6):
    """For training: process all tokens at once"""
    Q_prime = feature_map(Q)    # (batch, heads, seq, head_dim)
    K_prime = feature_map(K)

    # Key-Value aggregation: K'^T @ V → (batch, heads, head_dim, head_dim)
    # This is the d×d matrix — computed once, applied to all queries
    KV = torch.einsum('bhnd,bhne->bhde', K_prime, V)

    # Normalization denominator
    Z = K_prime.sum(dim=2)      # (batch, heads, head_dim)

    # Output: Q' @ KV
    numerator = torch.einsum('bhnd,bhde->bhne', Q_prime, KV)
    denominator = torch.einsum('bhnd,bhd->bhn', Q_prime, Z).unsqueeze(-1)

    return numerator / (denominator + eps)

Causal (autoregressive) version — needs cumulative sums:

   def causal_linear_attention(Q, K, V, feature_map, eps=1e-6):
    """Causal version: each position only attends to past positions.
    Uses cumulative sum to build running state.
    """
    Q_prime = feature_map(Q)
    K_prime = feature_map(K)

    # Per-position KV outer products
    KV = torch.einsum('bhnd,bhne->bhnde', K_prime, V)

    # Cumulative sum for causality: S_t = Σ_{i≤t} KV_i
    S_cumsum = torch.cumsum(KV, dim=2)     # Running state
    z_cumsum = torch.cumsum(K_prime, dim=2) # Running normalizer

    # Query against cumulative state at each position
    numerator = torch.einsum('bhnd,bhnde->bhne', Q_prime, S_cumsum)
    denominator = torch.einsum('bhnd,bhnd->bhn', Q_prime, z_cumsum).unsqueeze(-1)

    return numerator / (denominator + eps)

The causal version materializes the running state at every position via cumsum, which is $O(N \times d^2)$ total.

Recurrent Mode (Inference): $O(d^2)$ per step

   def linear_attention_step(q_t, k_t, v_t, S, z, feature_map, eps=1e-6):
    """Single step for autoregressive generation: O(d²) per token.

    S: (batch, heads, head_dim, head_dim) — accumulated key-value outer products
    z: (batch, heads, head_dim)           — accumulated normalizer

    The state S is a d×d matrix — fixed size regardless of sequence length.
    """
    q = feature_map(q_t)    # (batch, heads, head_dim)
    k = feature_map(k_t)

    # Update state: accumulate key-value outer product
    S = S + torch.einsum('bhd,bhe->bhde', k, v_t)  # O(d²)
    z = z + k                                        # O(d)

    # Query against accumulated state
    numerator = torch.einsum('bhd,bhde->bhe', q, S)  # O(d²)
    denominator = torch.einsum('bhd,bhd->bh', q, z).unsqueeze(-1)

    y = numerator / (denominator + eps)
    return y, S, z

The state $(S, z)$ has fixed size:

  • S: (heads, head_dim, head_dim) — the accumulated key-value outer products, a $d \times d$ matrix
  • z: (heads, head_dim) — the normalizer

This is $O(d^2)$ regardless of how many tokens have been processed. Compare with the KV cache in standard attention: $O(N \times d)$, which grows with every token.

Feature Maps: What Replaces Softmax?

The feature map φ must produce non-negative outputs (since attention weights should be non-negative). Common choices:

ELU + 1 (most common in practice):

   def feature_map_elu(x):
    return F.elu(x) + 1    # Always >= 0, smooth, differentiable everywhere

ReLU (simple but sparse):

   def feature_map_relu(x):
    return F.relu(x)        # Simple, but many zeros → sparse attention

Random Fourier Features (can approximate softmax kernel):

   def feature_map_rff(x, omega):
    """Random Fourier Features approximate exp(q·k/√d).
    This is the kernel trick: softmax attention ≈ φ(Q)φ(K)^T
    where φ maps to a random feature space.
    """
    projected = x @ omega    # Random projection
    return torch.cat([torch.cos(projected), torch.sin(projected)], dim=-1) / math.sqrt(omega.shape[1])

RFF is theoretically interesting because it can approximate the softmax kernel, but in practice ELU+1 is the most common due to simplicity and non-sparsity. The choice of feature map significantly affects the quality of attention patterns and is an active area of research.

Connection to SSMs: The Common Thread

Linear attention and SSMs are both linear recurrences, but with fundamentally different update rules:

Linear Attention — additive update, no decay:

$$S_t = S_{t-1} + K_t^\top \otimes V_t, \quad z_t = z_{t-1} + K_t, \quad y_t = \frac{Q_t \cdot S_t}{Q_t \cdot z_t}$$

SSM (Mamba) — multiplicative decay + additive input:

$$x_t = A_t \cdot x_{t-1} + B_t \cdot u_t, \quad y_t = C_t \cdot x_t$$

The critical difference is the decay factor $A$:

  • Linear attention accumulates without forgetting: $S$ grows in magnitude forever. Every token ever seen contributes equally to the state (weighted by its key). This makes it excellent for global aggregation tasks but problematic for long sequences where irrelevant context should be forgotten.
  • SSMs can selectively forget: The decay $A < 1$ means old information gradually fades, making room for new information. This is crucial for long sequences where the model must track evolving context.

The lack of decay in vanilla linear attention is actually a significant weakness. Consider a 100K-token sequence where only the last 1K tokens are relevant: linear attention still carries all 100K tokens’ worth of information in S, diluting the signal. Mamba’s exponential decay naturally focuses on recent context.
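
A quick sketch of the dilution effect: accumulate random key-value outer products with and without a decay factor and watch the state norm (γ = 0.99 here is an arbitrary illustrative decay):

   torch.manual_seed(0)
d, T, gamma = 16, 20_000, 0.99
S_plain = torch.zeros(d, d)      # vanilla linear attention state
S_decay = torch.zeros(d, d)      # same update with multiplicative decay

for _ in range(T):
    k, v = torch.randn(d), torch.randn(d)
    kv = torch.outer(k, v)
    S_plain = S_plain + kv               # accumulates every token forever
    S_decay = gamma * S_decay + kv       # old contributions fade geometrically

print(S_plain.norm().item(), S_decay.norm().item())   # the first keeps growing with T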

The Tradeoffs: Attention vs. Linear Attention

| Aspect | Standard Attention | Linear Attention |
|---|---|---|
| Expressiveness | Full rank-$N$ patterns | Limited to rank-$d$ patterns |
| Training | $O(N^2 d)$ | $O(Nd^2)$ |
| Inference | $O(Nd)$ per step | $O(d^2)$ per step |
| Sharp attention | Yes (softmax peaks) | No (smooth feature products) |
| Recall | Excellent (exact lookup) | Weaker (compressed into $d \times d$ state) |
| Forgetting | Implicit (via softmax normalization) | None (accumulates forever) |
| Long sequences | Expensive but precise | Cheap but lossy |

The reduced expressiveness is the main weakness: linear attention can only represent attention patterns that are separable into $\phi(Q) \cdot \phi(K)^\top$ products. The softmax in standard attention creates much sharper, more selective patterns — it can place almost all weight on a single token. Linear attention patterns are inherently smooth.

This motivates the next generation of linear attention variants — GLA, RWKV, RetNet — which add gating and decay to recover some of this lost expressiveness. They bridge the gap between pure linear attention (no decay, no gating) and full SSMs (learned decay, input-dependent gating).


6. The Landscape: GLA, RWKV, RetNet & Beyond

The success of S4 and Mamba triggered an explosion of subquadratic architectures. The field can be organized into four broad families, all sharing the “train parallel, infer recurrent” paradigm but differing in their update rules and gating mechanisms.

A Unified View: Update Rules and Gating

Every subquadratic architecture maintains a hidden state and updates it with each new token. The key design axes are:

Update rule — how does the state change?

  • Outer-product additive: $S \leftarrow S + k^\top \otimes v$ (linear attention family)
  • Delta rule: $S \leftarrow S + k^\top \otimes (v - S \cdot k)$ (error-correcting update)
  • SSM recurrence: $x \leftarrow A \cdot x + B \cdot u$ (decay + input)

Gating — how is the update modulated?

  • No gating: Fixed update strength (vanilla linear attention)
  • Scalar gating: Single decay factor per dimension (S4, RetNet)
  • Vector gating: Per-dimension decay, input-dependent (Mamba, GLA)
  • Matrix gating: Full state rotation (theoretical limit, too expensive in practice)

The progression from S4 through Mamba to GLA can be understood as moving along both axes: from simple update rules with no gating, toward richer update rules with input-dependent gating.
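
As single-step updates, the three rules differ only in one line (a sketch with S as a (d_k, d_v) state matrix and x as a vector SSM state; the delta rule writes only the part of v that the current state fails to predict for k):

   def additive_update(S, k, v):
    """Linear attention: accumulate the key-value outer product (never forgets)."""
    return S + torch.outer(k, v)

def delta_rule_update(S, k, v, beta=1.0):
    """Delta rule: correct the state toward v for key k, with step size beta."""
    return S + beta * torch.outer(k, v - k @ S)

def ssm_update(x, u, A_bar, B_bar):
    """SSM recurrence: decay the old state, then add the new input contribution."""
    return A_bar * x + B_bar * u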

Gated Linear Attention (GLA)

Key idea: Add a per-dimension, input-dependent gating (decay) to linear attention.

Standard linear attention accumulates state without decay:

$$S_t = S_{t-1} + k_t^\top \otimes v_t$$

GLA adds input-dependent gating:

$$S_t = G_t \odot S_{t-1} + k_t^\top \otimes v_t$$

where $G_t = \operatorname{diag}(g_t)$ is a diagonal gate computed from the input: $g_t = \sigma(W_g \cdot x_t)$.

This bridges linear attention and Mamba: like linear attention, it uses key-value outer products for state updates; like Mamba, it has input-dependent decay. GLA achieves competitive performance with Mamba while being more naturally expressed as attention operations.

Training: Uses a chunk-wise parallel algorithm similar to Mamba-2’s chunked computation. Within each chunk, the gated updates can be computed as masked matrix multiplications, enabling tensor core utilization. This makes GLA one of the most hardware-efficient subquadratic architectures.
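
In recurrent form, a GLA step looks like linear attention with one extra elementwise gate (a minimal sketch; W_q, W_k, W_v, W_g are hypothetical projection matrices, and real GLA uses low-rank gate projections plus the chunkwise-parallel training kernel mentioned above):

   def gla_recurrent_step(x_t, S, W_q, W_k, W_v, W_g):
    """One GLA step. x_t: (batch, d_model); S: (batch, d_k, d_v) running state."""
    q, k, v = x_t @ W_q, x_t @ W_k, x_t @ W_v
    g = torch.sigmoid(x_t @ W_g)                                 # per-dimension decay in (0, 1)
    S = g.unsqueeze(-1) * S + k.unsqueeze(-1) * v.unsqueeze(1)   # gate old state, write new outer product
    y = torch.einsum('bk,bkv->bv', q, S)                         # read the state with the query
    return y, S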

RWKV (v4 → v7): Reinventing RNNs from Scratch

RWKV (pronounced “RwaKuv”) is a linear-complexity architecture that evolved through seven major versions, developed largely by a single researcher (Bo Peng) and the open-source community — yet achieving results competitive with well-funded lab efforts.

RWKV-4/5: Channel-mixing + time-mixing with exponential decay:

$$wkv_t = \sum_{i=1}^{t-1} e^{-(t-1-i) \cdot w + k_i} \cdot v_i + e^{u + k_t} \cdot v_t$$

RWKV-6: Added data-dependent decay (like Mamba’s selection):

$$w_t = w_0 + \alpha \cdot \text{Linear}(x_t)$$

This is the key evolutionary step: RWKV-6’s input-dependent decay is analogous to Mamba’s selective mechanism. Both recognize that fixed dynamics are insufficient — the model needs to adapt its memory retention based on content.

RWKV-7 (“Goose”): Matrix-valued gating and improved state update (the Eagle and Finch names belong to RWKV-5 and RWKV-6). The 3B parameter RWKV-7 World model reportedly outperforms Llama 3.2 and Qwen2.5 on several benchmarks — a remarkable result for a community-driven project.

RetNet: Retentive Networks (Microsoft)

RetNet introduces a “retention” mechanism that makes the parallel-recurrent duality particularly clean and explicit:

Retention (single head):

$$\text{Retention}(X) = (QK^\top \odot D) \cdot V$$

where $D$ is a causal decay matrix:

$$D[i,j] = \begin{cases} \gamma^{i-j} & \text{for } i \geq j \\ 0 & \text{for } i < j \end{cases}$$

This is strikingly similar to Mamba-2’s semi-separable matrix, but with a fixed scalar decay $\gamma$ rather than input-dependent or matrix-valued decay. The simplicity is the point: RetNet cleanly demonstrates the three computation modes without the complexity of selective parameters.

RetNet supports three modes:

  • Parallel: Full matrix computation for training — compute $QK^\top$, apply decay mask $D$, multiply by $V$
  • Recurrent: $O(1)$ per step for inference — maintain running state $S_t = \gamma \cdot S_{t-1} + K_t^\top \otimes V_t$
  • Chunkwise: Hybrid for longer sequences — parallel within chunks, recurrent between chunks

RetNet’s contribution is primarily conceptual clarity: it makes the parallel-recurrent duality very explicit and provides a clean framework for thinking about the design space.
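
In code, the parallel mode is just a masked attention product (a sketch; the decay mask D plays the role that softmax plays in standard attention, and the usual scaling and normalization are omitted):

   def retention_parallel(Q, K, V, gamma=0.97):
    """Parallel-mode retention: (Q K^T ⊙ D) V with D[i, j] = γ^(i-j) for i ≥ j, else 0."""
    L_seq = Q.shape[-2]
    idx = torch.arange(L_seq)
    diff = (idx.unsqueeze(1) - idx.unsqueeze(0)).clamp(min=0)   # i - j, floored at 0
    D = torch.tril(gamma ** diff.float())                        # causal exponential-decay mask
    return (Q @ K.transpose(-1, -2) * D) @ V

The recurrent mode keeps the running state $S_t = \gamma \cdot S_{t-1} + K_t^\top \otimes V_t$ and reads it with $Q_t$, producing the same outputs one token at a time.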

HGRN2: Hierarchically Gated Linear RNNs

HGRN2 introduces an outer-product state expansion mechanism:

  • The original HGRN uses element-wise gated recurrence (small state, limited capacity)
  • HGRN2 replaces element-wise products with outer products (larger state, more expressive)
  • Diagonalizes the forget gate for efficient hidden state updates

This gives HGRN2 a linear attention interpretation while maintaining RNN efficiency. The 3B model matches Mamba and Transformer baselines on language modeling. Accepted at COLM 2024.

Griffin and Hawk (Google)

Google’s contribution to the subquadratic landscape:

Hawk: Pure gated linear recurrence using the RG-LRU (Real-Gated Linear Recurrent Unit):

$$x_t = a_t \odot x_{t-1} + \sqrt{1 - a_t^2} \odot (b_t \otimes \text{input}_t)$$

where $a_t$ is a learnable input-dependent gate. The $\sqrt{1 - a_t^2}$ factor ensures numerical stability — it normalizes the update so the state magnitude stays bounded.
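
A single step of this recurrence, following the equation above (a sketch; W_a and W_b are hypothetical gate projections, and the published RG-LRU also folds in a learned per-channel base decay):

   def rg_lru_step(u_t, h, W_a, W_b):
    """One RG-LRU-style step. u_t: (batch, d) input; h: (batch, d) recurrent state."""
    a = torch.sigmoid(u_t @ W_a)                       # recurrence gate a_t in (0, 1)
    b = torch.sigmoid(u_t @ W_b)                       # input gate b_t
    return a * h + torch.sqrt(1 - a ** 2) * (b * u_t)  # bounded-magnitude update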

Griffin: Linear recurrence + local sliding window attention.

Griffin’s key finding: the hybrid (recurrence + local attention) significantly outperforms pure recurrence on recall-intensive tasks. This anticipates the broader trend toward hybrid architectures that we’ll explore in the next section.

Based: Linear Attention with Sliding Window (Stanford)

Based takes a practical approach: combine linear attention for global context with a sliding window for local detail.

The insight: linear attention is good at capturing long-range dependencies but bad at sharp local patterns. Standard attention is great locally but expensive globally. So use both:

   output = SlidingWindowAttention(x, window=512) + LinearAttention(x)

This is one of the earliest “hybrid” designs, operating at the mechanism level (combining two attention types) rather than the layer level (alternating different layer types).

The Taxonomy: Four Generations

Based on survey literature (2025), subquadratic architectures can be roughly grouped into generations:

| Generation | Key Models | Innovation |
|---|---|---|
| Gen 1: Fixed recurrence | S4, Linear Attention, RetNet | Dual forms (parallel/recurrent), but fixed dynamics |
| Gen 2: Selective/gated | Mamba, GLA, RWKV-6 | Input-dependent parameters, content-aware processing |
| Gen 3: Hardware-aware | Mamba-2, GLA-2 | Tensor core utilization via chunked computation |
| Gen 4: Hybrid | Jamba, Zamba, Griffin, Samba | Mix attention + efficient layers for best of both |

The progression reveals a clear pattern: pure subquadratic models converge toward incorporating some attention, while attention-based models converge toward incorporating efficient recurrence. The middle ground — hybrid architectures — is where the field is consolidating.

Performance Reality Check (2025)

From recent surveys:

Where subquadratic models excel:

  • Long-context tasks (>8K tokens) — the efficiency advantage is decisive
  • Streaming / real-time applications — $O(1)$ per-token generation
  • Memory-constrained inference — fixed state instead of growing KV cache
  • Throughput-critical serving — higher batch sizes possible

Where attention still dominates:

  • Recall-intensive tasks (needle-in-a-haystack) — exact lookup from arbitrary positions
  • Tasks requiring precise long-range retrieval — “what was the value on line 47?”
  • Large-scale models (14B+) — no pure subquadratic models exist at this scale yet
  • Tasks where the quality-efficiency tradeoff doesn’t matter

The emerging consensus: At the 7B scale, pure SSM models (Falcon Mamba 7B) are competitive with Transformers. But at 14-70B, only hybrid architectures (Jamba, Griffin) remain competitive — no pure subquadratic model has demonstrated transformer-equivalent quality above 14B parameters.

This strongly suggests that the future isn’t “SSM vs. Attention” but “SSM + Attention” — hybrid architectures that use each mechanism where it’s most effective.


7. Hybrid Architectures — The Best of Both Worlds

Why Hybrid?

The recall problem is fundamental: compressing the entire sequence history into a fixed-size state is inherently lossy. No matter how clever the compression (HiPPO, selective scanning, gated recurrence), there exist tasks where you need to look up a specific token from arbitrarily far in the past.

Standard attention solves this trivially — the KV cache retains everything, and softmax can place arbitrarily sharp weight on any position. But it costs $O(N)$ per token at inference.

The pragmatic solution: use attention where you need recall, and efficient layers everywhere else.

The Design Space

Hybrid architectures vary along several dimensions:

| Dimension | Options |
|---|---|
| Ratio | What fraction of layers use attention? (1/8, 1/4, 1/3, 1/2) |
| Pattern | Interleaved? Attention at bottom/top? Every-K-th layer? |
| Attention type | Full attention? Sliding window? Grouped-query? |
| Efficient layer | Mamba? Mamba-2? Linear attention? Gated linear recurrence? |
| Integration | Sequential (stacked layers)? Parallel (summed outputs)? |
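
In code, the most common configuration (interleaved, every K-th layer) is just a layer-selection rule; the sketch below assumes mamba_block and attn_block are constructors for whatever block implementations you use (real hybrids such as Jamba also insert MoE layers, normalization, and so on):

   class HybridStack(nn.Module):
    """Place one attention block every `attn_every` layers, efficient blocks elsewhere."""
    def __init__(self, d_model, mamba_block, attn_block, n_layers=32, attn_every=8):
        super().__init__()
        self.layers = nn.ModuleList([
            attn_block(d_model) if (i + 1) % attn_every == 0 else mamba_block(d_model)
            for i in range(n_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)    # residual connection; normalization omitted for brevity
        return x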

Jamba (AI21 Labs)

Jamba (March 2024) was one of the first production-grade hybrid architectures.

Architecture:

  • Interleaves Mamba layers with Transformer attention layers at 1:7 ratio — one attention layer for every seven Mamba layers
  • Adds Mixture of Experts (MoE) for capacity: 52B total parameters, 12B active
  • 256K context window support
  • Matches Mixtral 8x7B quality while fitting on a single 80GB GPU
  • 3x throughput improvement over comparable Transformers at long contexts

Why 1:7? The attention layers serve as periodic “recall checkpoints” — they allow the model to perform precise retrieval from the full context. Seven Mamba layers between checkpoints provide efficient processing with exponentially decaying but still-useful memory. Empirically, more frequent attention layers show diminishing returns for most tasks.
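
A toy illustration of what a 1:7 interleave looks like in practice (the indices are illustrative, not taken from the released model):

   # One Jamba-style block: seven Mamba layers followed by one attention layer.
   block = ["mamba"] * 7 + ["attention"]
   stack = block * 4                      # a 32-layer stack built from 4 such blocks

   attention_positions = [i for i, t in enumerate(stack) if t == "attention"]
   print(attention_positions)             # [7, 15, 23, 31] -> periodic recall checkpoints
   print(len(attention_positions) / len(stack))   # 0.125 -> only 1/8 of layers keep a KV cache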

Zamba (Zyphra)

Architecture:

  • Mamba backbone with shared attention layers — a single attention module with shared weights is reused at multiple positions
  • Drastically reduced parameter count from attention while maintaining recall capability

Key innovation: Weight sharing for attention layers. Since the attention layers mainly serve as “recall refreshers,” they don’t need unique parameters at each position. The same attention module can re-attend at different depths, providing recall without the parameter overhead of multiple distinct attention layers.
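
A hedged sketch of the weight-sharing idea (not Zyphra's released code; the dimensions and reuse interval are made up, the Mamba blocks are simple stand-ins, and causal masking is omitted):

   import torch
   import torch.nn as nn

   class SharedAttentionBackbone(nn.Module):
       """Backbone that re-applies ONE attention module at several depths."""
       def __init__(self, d_model=2048, n_heads=16, n_layers=24, reuse_every=6):
           super().__init__()
           # Stand-ins for Mamba blocks (a real model would use mamba_ssm modules here).
           self.backbone = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
           # A single attention module whose weights are shared across all reuse points.
           self.shared_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
           self.reuse_every = reuse_every

       def forward(self, x):  # x: (batch, seq, d_model)
           for i, block in enumerate(self.backbone):
               x = x + block(x)
               if (i + 1) % self.reuse_every == 0:
                   refreshed, _ = self.shared_attn(x, x, x)   # same parameters every time
                   x = x + refreshed
           return x

   model = SharedAttentionBackbone()
   out = model(torch.randn(1, 128, 2048))
   print(out.shape)   # torch.Size([1, 128, 2048])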

Zamba2-7B reportedly achieves state-of-the-art quality among 7B models, outperforming both Llama 3.1 8B and Mistral 7B on many benchmarks.

Samba (Microsoft)

Architecture:

  • Alternates Mamba layers with sliding window attention layers
  • Sliding window size is relatively small (e.g., 2048 tokens)
  • Mamba handles long-range dependencies; sliding window handles local patterns

Why sliding window? Full attention would negate the efficiency gains. But a small sliding window is cheap — O(W) per token, where W is the window size — and captures the local detail that SSMs often miss. The combination is powerful: Mamba provides the long-range context highway, and sliding window attention provides the sharp local pattern recognition.
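
A minimal sketch of the cost intuition (toy numbers; the mask helper is illustrative, not Samba's implementation):

   import torch

   def sliding_window_mask(n: int, window: int) -> torch.Tensor:
       """True where position i may attend to position j: the previous `window` tokens."""
       i = torch.arange(n).unsqueeze(1)
       j = torch.arange(n).unsqueeze(0)
       return (j <= i) & (j > i - window)

   print(sliding_window_mask(5, 3).long())
   # tensor([[1, 0, 0, 0, 0],
   #         [1, 1, 0, 0, 0],
   #         [1, 1, 1, 0, 0],
   #         [0, 1, 1, 1, 0],
   #         [0, 0, 1, 1, 1]])

   N, W = 50_000, 2_048
   print(f"full attention, keys per query at the end of the sequence: {N:,}")
   print(f"sliding window, keys per query (constant):                 {W:,}")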

Results: Samba and RWKV-7 reportedly outperform full-attention Llama 3.2 and Qwen2.5 baselines on several benchmarks at the 7B scale.

Falcon H1 (TII)

Architecture:

  • Combines Mamba-2 (SSD) efficient layers with grouped-query attention (see the comparison table below)

TII’s evolution from Falcon Mamba (pure SSM) to Falcon H1 (hybrid) mirrors the field’s trajectory: pure SSMs are competitive, but hybrids push the frontier further.

RecurrentGemma (Google)

Architecture:

  • Uses Griffin (gated linear recurrence + local attention) as its backbone
  • Linear recurrence layers for most of the processing
  • Local sliding window attention layers for recall
  • Available at 2B and 9B scales

Results show it’s competitive with Gemma at the same scale while offering better inference efficiency for long sequences. Google’s willingness to ship a recurrent-hybrid architecture signals confidence in the approach.

Design Principles That Have Emerged

From studying these hybrid architectures, several consistent principles emerge:

(1). Attention Ratio: Less Than You’d Think

Most successful hybrids use only 10-25% attention layers. This suggests:

  • A few attention layers provide sufficient recall capability
  • The majority of “thinking” can happen in efficient recurrent layers
  • Diminishing returns from adding more attention layers — the first attention layer provides most of the recall benefit

(2). Placement Matters

Where attention layers go affects performance:

  • Interleaved (every K-th layer): Most common, provides periodic recall refresh
  • Bottom-heavy: Some evidence that early layers benefit more from attention (captures input structure)
  • Top-heavy: The last few layers may need attention for final output quality (sharper predictions)

(3). The “State Refresh” Pattern

The hybrid architecture can be understood through a “state refresh” metaphor:

   [SSM] → [SSM] → [SSM] → ... → [Attention] → [SSM] → [SSM] → ...
  efficient       efficient         recall         efficient
  processing      processing        refresh        processing

The SSM layers process the sequence efficiently, building up a compressed representation. But this compression is lossy — some information is inevitably lost or blurred. The attention layer acts as a “state refresh” — it allows the model to precisely retrieve information that may have been compressed or lost in the SSM state. The subsequent SSM layers then process this refreshed, more accurate information.

(4). Attention Type: Full vs. Sliding Window

  • Full attention (Jamba): Maximum recall, but O(N) cost at each attention layer
  • Sliding window (Samba): Cheaper (O(W) cost), but only local recall
  • Grouped-query (Falcon H1): Reduced KV cache while maintaining quality

The choice depends on whether you need global recall (exact retrieval from anywhere in the sequence) or just local recall (sharp pattern matching in recent context).

Comparison Table

| Model | Efficient Layer | Attention Type | Ratio | Scale | Notable |
| --- | --- | --- | --- | --- | --- |
| Jamba | Mamba | Full | 1:7 | 52B (12B active) | 256K context |
| Zamba2 | Mamba | Shared full | Weight-shared | 7B | Parameter efficient |
| Samba | Mamba | Sliding window | ~1:1 | 7B | Strong benchmarks |
| Falcon H1 | Mamba-2 (SSD) | GQA | Mixed | 7B | Reasoning focus |
| Griffin | Gated LRU | Sliding window | Mixed | 2-14B | Google’s approach |
| RecurrentGemma | Griffin | Local | Mixed | 2-9B | Production ready |

The Inference Advantage: Concrete Numbers

Even with some attention layers, hybrids offer significant inference advantages:

Pure Transformer (all 32 layers are attention):

  • Total KV cache: O(L × 32 × d_kv) for context length L — every layer stores K and V for every token
  • At 100K context: 32 layers × 2 (K, V) × 100K tokens × d_kv × 2 bytes ≈ 6.4 GB KV cache (with a GQA-style d_kv = n_kv_heads × d_head)

Hybrid (25% attention = 8 attention layers, 75% SSM = 24 layers):

  • KV cache: only for 8 layers → 8 × 2 × 100K × d_kv × 2 bytes ≈ 1.6 GB
  • SSM state for 24 layers: 24 × d_inner × d_state ≈ negligible (a few MB to tens of MB, independent of context length)
  • Total: ~1.6 GB — a 4x reduction in memory

The SSM state is negligible compared to the KV cache at any reasonable context length. The memory savings translate directly to larger batch sizes and lower latency.
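
A back-of-the-envelope version of the numbers above (the KV dimensionality and bf16 assumption are illustrative; they roughly reproduce the figures in the text):

   def kv_cache_gb(n_tokens: int, n_attn_layers: int, kv_dim: int = 512, bytes_per: int = 2) -> float:
       # K and V are stored for every attention layer, every token: 2 * layers * tokens * dim * bytes
       return 2 * n_attn_layers * n_tokens * kv_dim * bytes_per / 1e9

   ctx = 100_000
   full   = kv_cache_gb(ctx, n_attn_layers=32)   # every layer is attention
   hybrid = kv_cache_gb(ctx, n_attn_layers=8)    # 25% attention layers
   print(f"pure Transformer: {full:.1f} GB | hybrid: {hybrid:.1f} GB | reduction: {full / hybrid:.0f}x")
   # pure Transformer: 6.6 GB | hybrid: 1.6 GB | reduction: 4x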

When to Use What

| Scenario | Recommended Architecture |
| --- | --- |
| Short context (<4K), quality-critical | Pure Transformer |
| Medium context (4K-32K), balanced | Hybrid (Jamba-style) |
| Long context (32K-256K) | Hybrid with sliding window (Samba-style) |
| Streaming / real-time | Pure SSM or SSM-heavy hybrid |
| Memory-constrained (edge, mobile) | Pure SSM (Falcon Mamba) |
| Maximum recall needed (RAG, tool use) | Transformer or attention-heavy hybrid |
| Agentic RL with long rollouts | Hybrid — SSM for speed, attention for tool-use recall |

8. Practical Implications and Choosing the Right Architecture

The Rollout Bottleneck, Revisited

The SSM revolution isn’t just academic architecture research — it directly impacts the most demanding workload in modern AI: agentic reinforcement learning training.

In another RL infrastructure blog post of mine, I documented that 80-90% of RL training time is spent on sample generation (rollouts). Let’s make the speedup concrete:

   Traditional Transformer rollout (50K tokens):
  Token 1:     attend to 1 token      → 1 unit of work
  Token 1000:  attend to 1000 tokens  → 1,000 units of work
  Token 50K:   attend to 50K tokens   → 50,000 units of work

  Total: Σ_{i=1}^{50K} i ≈ 1.25 billion operations

   SSM rollout (Mamba/Mamba-2, 50K tokens):
  Token 1:     O(d²)  ≈ 1 unit of work
  Token 1000:  O(d²)  ≈ 1 unit of work
  Token 50K:   O(d²)  ≈ 1 unit of work

  Total: 50K × d² ≈ 50K units of work

That’s a ~25,000x reduction in total compute for the generation phase. Even after accounting for SSM steps being more expensive per operation than a single attention dot product, and for the fact that real multi-layer models make the comparison less than perfectly apples-to-apples, the speedup for long-context generation is transformative.
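
The arithmetic above, reproduced in a few lines (units are abstract "steps of attention/state work", not FLOPs for a real multi-layer model):

   n = 50_000
   attention_units = n * (n + 1) // 2     # sum_{i=1}^{n} i: each token attends to its full history
   ssm_units = n                          # one O(1) state update per generated token
   print(f"{attention_units:,} vs {ssm_units:,} -> ~{attention_units // ssm_units:,}x fewer units")
   # 1,250,025,000 vs 50,000 -> ~25,000x fewer units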

Concrete Impact on Agentic RL Training

Consider a Slime-style agentic RL setup where:

  • The agent interacts with a code execution environment
  • Average rollout length: 30K tokens (prompt + multiple tool calls + reasoning)
  • Training requires 100K rollouts per update

With a pure Transformer (7B scale):

  • KV cache per rollout: ~4 GB
  • Maximum batch size limited by GPU memory → fewer parallel rollouts
  • Generation dominates training wall-clock time

With a hybrid (75% SSM, 25% attention):

  • State per rollout: ~1 GB (4x less KV cache + small SSM state)
  • 4x larger batch size → better GPU utilization (see the quick arithmetic below)
  • Generation time scales near-linearly with rollout length instead of quadratically

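A quick sanity check on the batch-size claim (the 80 GB GPU and ~14 GB for bf16 weights are assumptions; activation overhead is ignored):

   gpu_memory_gb = 80
   weights_gb = 14                      # ~7B parameters in bf16
   free_gb = gpu_memory_gb - weights_gb

   per_rollout_transformer_gb = 4.0     # KV cache for a ~30K-token rollout
   per_rollout_hybrid_gb = 1.0          # reduced KV cache + small SSM state

   print(int(free_gb / per_rollout_transformer_gb))   # ~16 concurrent rollouts
   print(int(free_gb / per_rollout_hybrid_gb))        # ~66 concurrent rollouts (roughly 4x more)
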
This isn’t hypothetical — it’s why companies like Cartesia are building SSM-based models specifically for real-time, long-context applications, and why Falcon H1 uses a Mamba-2 hybrid backbone.

The Recall-Efficiency Frontier

The central tension in the field can be visualized as a Pareto frontier:

   Quality (recall) ↑
                 |   * Transformer (full attention)
                 |  * Hybrid (25% attention)
                 | * Hybrid (10% attention)
                 |
                 |        * Mamba-2
                 |       * Mamba
                 |
                 |                    * S4
                 |
                 +------------------------→ Efficiency (tokens/sec)

Each architecture occupies a point on this frontier. The field’s progress has been to push the frontier outward — getting more quality for the same efficiency, or more efficiency for the same quality.

Key developments pushing the frontier:

  • Mamba (2023): Selective parameters dramatically improved quality for pure SSMs
  • Mamba-2 (2024): Hardware-efficient SSD pushed throughput without quality loss
  • Hybrid architectures (2024-2025): Mixing in minimal attention recovered recall capability
  • GLA, RWKV-7 (2024-2025): Better gating/update rules closed the quality gap further

Decision Framework

   1. What is your maximum acceptable inference latency per token?
      └─ Unconstrained → Pure Transformer (simplest, best quality)
      └─ Latency-sensitive → Continue to 2

   2. What is your typical context length?
      └─ <4K tokens → Pure Transformer (quadratic cost is manageable)
      └─ 4K-32K → Hybrid architecture (Jamba, Falcon H1)
      └─ >32K → SSM-heavy hybrid or pure SSM

   3. Do you need precise long-range retrieval?
      └─ Yes (RAG, tool use, code reference) → Hybrid with full attention layers
      └─ No (summarization, open-ended generation) → Pure SSM may suffice

   4. What are your memory constraints?
      └─ Limited (edge, mobile) → Pure SSM
      └─ Standard (A100/H100) → Hybrid
      └─ Abundant (multi-GPU) → Whatever maximizes quality
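
One possible flattening of this decision tree into a small function (the ordering of checks is a judgment call, and the thresholds mirror the text as rules of thumb, not hard boundaries):

   def pick_architecture(context_len: int, latency_sensitive: bool,
                         needs_exact_recall: bool, memory_limited: bool) -> str:
       if memory_limited:
           return "pure SSM (Falcon Mamba class)"
       if not latency_sensitive and context_len < 4_000:
           return "pure Transformer"
       if needs_exact_recall:
           return "hybrid with full attention layers (Jamba-style)"
       if context_len > 32_000:
           return "SSM-heavy hybrid or pure SSM"
       return "hybrid (Jamba / Falcon H1 style)"

   print(pick_architecture(100_000, latency_sensitive=True,
                           needs_exact_recall=False, memory_limited=False))
   # SSM-heavy hybrid or pure SSM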

Ecosystem Maturity

One practical consideration that often dominates architecture choices: software ecosystem maturity.

| Architecture | Frameworks | Pre-trained Models | Custom Kernel Support |
| --- | --- | --- | --- |
| Transformer | PyTorch, JAX, TensorRT, vLLM | Abundant (Llama, GPT, Qwen, etc.) | Flash Attention, PagedAttention |
| Mamba | mamba-ssm (CUDA), Triton | Mamba (130M-2.8B), Falcon Mamba 7B | Mamba CUDA kernels |
| Mamba-2 | mamba-ssm v2 | Limited | SSD Triton kernels |
| Hybrid | Framework-dependent | Jamba, Zamba, Falcon H1 | Mixed (attention + SSM kernels) |

The Transformer ecosystem is vastly more mature. Choosing an SSM or hybrid architecture means fewer pre-trained checkpoints, less optimized serving infrastructure, and more custom engineering. This gap is closing — particularly as frameworks like vLLM add Mamba support — but remains significant in 2025.


9. The Road Ahead

Open Questions

(1). Scaling Beyond 14B

No pure subquadratic model has demonstrated Transformer-equivalent quality above roughly 7-14B parameters. Can the SSM advantages hold at 70B+? The theory suggests yes — the efficiency gains become even more pronounced at scale, since the gap between O(N²) and O(N) widens with context length. But no one has demonstrated it empirically. The cost of training at this scale is prohibitive for most labs, and the companies that can afford it (OpenAI, Google, Anthropic, Meta) have invested heavily in Transformer infrastructure.

(2). Adaptive State Sizes

Current SSMs use fixed-size states (d_inner × d_state). Could we have adaptive state sizes that grow when the content demands it? Log-linear attention is one approach: the state grows logarithmically with sequence length, maintaining a Pareto-optimal position between fixed-state RNNs and full-state attention.

(3). The Role of Attention in Reasoning

There’s emerging evidence that attention may be particularly important for chain-of-thought reasoning, where the model needs to refer back to intermediate results. Consider a 10-step math derivation: the model must reference step 3’s result when computing step 7. SSMs’ exponentially decaying state may struggle here, while attention can precisely retrieve any previous step.

If this is true, the optimal hybrid ratio may be task-dependent:

  • Math reasoning: more attention needed (precise reference to intermediate steps)
  • Code generation: SSM-heavy for speed, attention at verification points
  • Conversational: mostly SSM (recent context dominates)
  • Long-document QA: hybrid with full attention for retrieval

(4). Hardware Co-Design

Mamba-2’s success came partly from aligning the computation with tensor cores. Future architectures may be designed in tandem with hardware:

  • GPUs (H100/H200): Favor large matrix multiplications → chunked SSD fits well
  • TPUs: Favor different operation patterns → may enable different SSM designs
  • Emerging architectures (Groq, Cerebras): May favor different primitives entirely
  • The “best” architecture may be hardware-dependent

(5). Convergence with Control Theory

SSMs have deep roots in control theory and signal processing. As the field matures, we may see more cross-pollination:

  • Better initialization from control-theoretic analysis (beyond HiPPO)
  • Stability guarantees from Lyapunov theory (provably stable recurrences)
  • Optimal filtering connections (Kalman filter → SSM with learned noise model)
  • Observer design principles for state estimation in noisy sequences

The Arc of Progress

   2017: Transformer   — O(N²), parallelizable, powerful
2021: S4            — O(N log N)/O(1), dual forms, but time-invariant
2023: Mamba         — O(N)/O(1), content-aware selection
2024: Mamba-2       — O(N)/O(1), tensor cores, SSM=Attention duality
2024: Hybrids       — O(N) amortized, best of both worlds
2025: Convergence   — The architecture is secondary;
                      the data, training recipe, and scale matter more

The trajectory suggests that within a few years, the attention-vs-SSM debate will be as settled as the CNN-vs-RNN debate in NLP was by 2018. The answer, as always, will be: use the right tool for each layer.

The winning architectures will be those that make this choice adaptively — attending precisely when recall demands it, flowing efficiently when it doesn’t. Not as a static design choice made once at architecture time, but potentially as a dynamic decision at each layer, for each sequence, at each position.

The question isn’t whether Transformers will be replaced. It’s whether each layer in a model needs to be a Transformer. Increasingly, the answer is no — and the layers that aren’t are getting faster every month.


References

Foundational Papers

Architecture Variants

Hybrid Architectures

Surveys

Implementations and Resources