Swapping GDN for GDN-2 on Qwen3.5

I remember coming across the Gated DeltaNet-2 paper a little while back and thinking, hey, this would make a pretty good architecture for a new Qwen3.5-like model. However, a whole month later, no one has meaningfully adopted this linear attention yet. This got me thinking: if I can't meaningfully train a new model from scratch, and GDN-2 is a generalization of GDN, can't I initialize GDN-2 to behave like GDN and fine-tune Qwen?

This is completely possible. Take a look at the gated delta rule:

S_t = α_t (I − β_t k_t k_tᵀ) S_{t−1} + β_t k_t v_tᵀ
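
To make the shapes concrete, here is a minimal NumPy sketch of one step of the gated delta rule above. It assumes the state S has shape (d_k, d_v), the key k is L2-normalized, and α_t and β_t are already-computed scalars in (0, 1). The function name `gdn_step` is hypothetical; real kernels (such as those in flash-linear-attention) process whole chunks in parallel instead of looping token by token.

```python
import numpy as np


def gdn_step(S, k, v, alpha, beta):
    """One gated delta rule update: S_t = alpha (I - beta k k^T) S_{t-1} + beta k v^T.

    S: (d_k, d_v) state, k: (d_k,) key, v: (d_v,) value,
    alpha, beta: scalars in (0, 1).
    """
    d_k = k.shape[0]
    transition = np.eye(d_k) - beta * np.outer(k, k)  # (d_k, d_k)
    return alpha * transition @ S + beta * np.outer(k, v)


rng = np.random.default_rng(0)
d_k, d_v = 4, 3
S = np.zeros((d_k, d_v))
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)  # unit-norm key
v = rng.standard_normal(d_v)
S = gdn_step(S, k, v, alpha=0.9, beta=0.5)
# Reading the state with the same key (q = k) returns beta * v here,
# because the state started at zero and k has unit norm.
print(S.T @ k)
```

The readout `S.T @ q` follows from the state storing key-value outer products k vᵀ.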

The gated delta rule-2 (wonderful name, by the way) is not all that different. It decouples the original scalar β gate into two vector gates: an erase gate (b) on the key axis and a write gate (w) on the value axis. It also replaces the scalar decay α_t with a diagonal decay matrix D_t.

S_t = (I − k_t (b_t ⊙ k_t)ᵀ) D_t S_{t−1} + k_t (w_t ⊙ v_t)ᵀ
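
A matching NumPy sketch of one gated delta rule-2 step, reading the shapes off the equation: b_t multiplies k_t elementwise, so it has d_k entries; w_t multiplies v_t elementwise, so it has d_v entries; D_t is stored as its diagonal of length d_k. The function name `gdn2_step` and the argument names `decay`, `erase`, `write` are hypothetical labels for D_t, b_t and w_t, not names from the paper or from flash-linear-attention.

```python
import numpy as np


def gdn2_step(S, k, v, decay, erase, write):
    """One gated delta rule-2 update:

    S_t = (I - k (erase * k)^T) Diag(decay) S_{t-1} + k (write * v)^T

    S: (d_k, d_v) state; k, decay, erase: (d_k,); v, write: (d_v,).
    """
    d_k = k.shape[0]
    # Erase gate b_t acts elementwise on the key axis inside the rank-1 term.
    transition = np.eye(d_k) - np.outer(k, erase * k)
    # Diag(decay) @ S scales row i of the state by decay[i].
    decayed = decay[:, None] * S
    # Write gate w_t acts elementwise on the value axis.
    return transition @ decayed + np.outer(k, write * v)


rng = np.random.default_rng(0)
d_k, d_v = 4, 3
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)
v = rng.standard_normal(d_v)
S = gdn2_step(
    np.zeros((d_k, d_v)), k, v,
    decay=np.full(d_k, 0.9),
    erase=np.full(d_k, 0.5),
    write=np.array([1.0, 0.0, 1.0]),  # close the write gate on value channel 1
)
print(S[:, 1])  # zero column: nothing was written to value channel 1
```

Closing one entry of the write gate blocks writes to that value channel for this token, which a single scalar β cannot do.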

The paper literally mentions that in the special case where b_t = β_t · 1_{d_k} and w_t = β_t · 1_{d_v}, this becomes Kimi Delta Attention (KDA). If you also set D_t = α_t I, you get the original Gated DeltaNet. Claude did a good job simplifying this equation, so I'll paste its derivation below:

S_t = (I − k_t (𝐛_t ⊙ k_t)ᵀ) D_t S_{t−1} + k_t (𝐰_t ⊙ v_t)ᵀ

(i) Substitute the erase gate. Since 𝐛_t ⊙ k_t = (β_t 1_{d_k}) ⊙ k_t = β_t k_t:
S_t = (I − k_t (β_t k_t)ᵀ) D_t S_{t−1} + k_t (𝐰_t ⊙ v_t)ᵀ

(ii) Pull the scalar β_t out of the outer product, k_t (β_t k_t)ᵀ = β_t k_t k_tᵀ:
S_t = (I − β_t k_t k_tᵀ) D_t S_{t−1} + k_t (𝐰_t ⊙ v_t)ᵀ

(iii) Substitute the decay. D_t = Diag(α_t 1_{d_k}) = α_t I:
S_t = (I − β_t k_t k_tᵀ) α_t S_{t−1} + k_t (𝐰_t ⊙ v_t)ᵀ

(iv) α_t is a scalar, so it commutes with the matrix (I − β_t k_t k_tᵀ); pull it to the front:
S_t = α_t (I − β_t k_t k_tᵀ) S_{t−1} + k_t (𝐰_t ⊙ v_t)ᵀ

(v) Substitute the write gate. 𝐰_t ⊙ v_t = (β_t 1_{d_v}) ⊙ v_t = β_t v_t:
S_t = α_t (I − β_t k_t k_tᵀ) S_{t−1} + k_t (β_t v_t)ᵀ

(vi) Pull the scalar β_t out of the write term:
S_t = α_t (I − β_t k_t k_tᵀ) S_{t−1} + β_t k_t v_tᵀ
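
The derivation can also be checked numerically. This is a minimal NumPy sketch, assuming unit-norm keys and random scalar gates: it runs GDN and GDN-2 side by side over a random sequence, ties the GDN-2 gates as in steps (i), (iii) and (v), and compares the final states. Both step functions are hypothetical reference implementations written for this check, not the flash-linear-attention kernels.

```python
import numpy as np


def gdn_step(S, k, v, alpha, beta):
    # Gated delta rule with scalar decay alpha and scalar gate beta.
    return alpha * (np.eye(len(k)) - beta * np.outer(k, k)) @ S + beta * np.outer(k, v)


def gdn2_step(S, k, v, decay, erase, write):
    # Gated delta rule-2: per-channel decay and erase (key axis), write (value axis).
    transition = np.eye(len(k)) - np.outer(k, erase * k)
    return transition @ (decay[:, None] * S) + np.outer(k, write * v)


rng = np.random.default_rng(0)
d_k, d_v, T = 8, 6, 32
S_gdn = np.zeros((d_k, d_v))
S_gdn2 = np.zeros((d_k, d_v))
for _ in range(T):
    k = rng.standard_normal(d_k)
    k /= np.linalg.norm(k)
    v = rng.standard_normal(d_v)
    alpha, beta = rng.uniform(0.5, 1.0), rng.uniform(0.0, 1.0)
    S_gdn = gdn_step(S_gdn, k, v, alpha, beta)
    # Tie the GDN-2 gates to the GDN scalars: D = alpha*I, b = beta*1_{d_k}, w = beta*1_{d_v}.
    S_gdn2 = gdn2_step(
        S_gdn2, k, v,
        decay=np.full(d_k, alpha),
        erase=np.full(d_k, beta),
        write=np.full(d_v, beta),
    )

max_diff = np.abs(S_gdn - S_gdn2).max()
print(max_diff)  # should sit at floating-point round-off level
```

The only difference between the two paths is where α_t is multiplied in, which is exactly the reordering in step (iv).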

Calling all you Ornith-1.0s and Nex-N2s out there: why not swap GDN for GDN-2?

Experiment

So I rented a small node on vast.ai with an RTX PRO 4500 Blackwell for about a day to see if I could implement this. This was also my test run of Kimi K2.7 Code, which I have mixed feelings about after this experiment: it's pretty good when you give it exact directions but struggles to think outside the box. Not a bad coding companion for when you still want to do some of the thinking yourself. The source code for this project is here.

My goal was to implement Qwen3.5-0.8B, the smallest model of the series, with a GDN-2 adapter. I wanted to implement both the forward and backward passes so I could run supervised fine-tuning (SFT) on long-context data and see whether the model could learn better gate and decay values starting from the existing weights. Here is what I implemented and prepared:

  1. A PyTorch model for GDN-2 based on the kernels in flash-linear-attention.
  2. A small weight loader to adapt GDN weights to GDN-2.
  3. Long-context instruction traces from zai-org/LongAlign-10k, Yukang/LongAlpaca-12k, wenbopan/anti-haystack.
  4. Model definitions for RULER.
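
One way a weight loader can make GDN-2 start out behaving exactly like GDN is to expand each per-head scalar gate projection into a per-channel one by repeating its row, so every channel of a head produces the same logit as the original head. This is a minimal NumPy sketch of that idea, not the project's actual loader; `expand_head_gate`, `W_beta` and the dimensions are hypothetical, and it assumes the GDN β gate is a sigmoid of a linear projection with one output per head.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def expand_head_gate(weight, channels_per_head):
    """Turn a per-head gate projection into a per-channel one.

    weight: (num_heads, hidden), one gate logit per head (e.g. the GDN beta).
    Returns (num_heads * channels_per_head, hidden): each head's row is repeated
    so every channel of that head starts out with the same logit.
    """
    return np.repeat(weight, channels_per_head, axis=0)


rng = np.random.default_rng(0)
num_heads, d_k, d_v, hidden = 4, 16, 32, 64
W_beta = rng.standard_normal((num_heads, hidden)) * 0.02  # hypothetical GDN beta projection

# GDN-2 needs an erase gate per key channel and a write gate per value channel.
W_erase = expand_head_gate(W_beta, d_k)  # (num_heads * d_k, hidden)
W_write = expand_head_gate(W_beta, d_v)  # (num_heads * d_v, hidden)

x = rng.standard_normal(hidden)
beta = sigmoid(W_beta @ x)                            # (num_heads,)
erase = sigmoid(W_erase @ x).reshape(num_heads, d_k)  # (num_heads, d_k)
write = sigmoid(W_write @ x).reshape(num_heads, d_v)  # (num_heads, d_v)
# At initialization every channel reproduces its head's scalar beta.
print(np.allclose(erase, beta[:, None]), np.allclose(write, beta[:, None]))
```

`np.repeat` along axis 0 keeps each head's copies contiguous, so reshaping to (num_heads, channels) groups them per head. If the GDN-2 decay is per key channel, the per-head decay parameters could be repeated d_k times in the same way; whether the real loader does this is an assumption.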

I was half expecting some numerical divergence just from the way the linear attention kernels are computed, but I couldn't find anything meaningful. Even at 16k tokens on a variety of prompts, the GDN-2 kernel's top-1 predictions with the stock GDN weights do not diverge from GDN at all. That could also just be because the fla GDN-2 kernel was built off the GDN kernel, so the two effectively reduce to the same computation, but I didn't look that closely.
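
A minimal sketch of the kind of check described above: given logits from the two models for the same prompt, count how often the top-1 token agrees and record the largest logit gap. It assumes the logits are already available as NumPy arrays of shape (seq_len, vocab); `compare_logits` is a hypothetical helper, not part of the project's code, and the random arrays below only stand in for real model outputs.

```python
import numpy as np


def compare_logits(logits_a, logits_b):
    """Compare two (seq_len, vocab) logit arrays from the same prompt.

    Returns the fraction of positions whose argmax (top-1) token agrees and
    the largest absolute logit difference.
    """
    top1_a = logits_a.argmax(axis=-1)
    top1_b = logits_b.argmax(axis=-1)
    agreement = float((top1_a == top1_b).mean())
    max_abs_diff = float(np.abs(logits_a - logits_b).max())
    return agreement, max_abs_diff


rng = np.random.default_rng(0)
ref = rng.standard_normal((16, 100)).astype(np.float32)
# Stand-in for the swapped model: a tiny perturbation, like kernel round-off.
other = ref + 1e-6 * rng.standard_normal(ref.shape).astype(np.float32)
print(compare_logits(ref, other))
```

Top-1 agreement alone can hide small drifts, which is why the maximum logit difference is reported alongside it.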

For SFT I wanted to try both the really cheap method of fine-tuning only the modified gates and plain SFT of the full model. Gates-only fine-tuning took less than 10 minutes; the full model still only took about half an hour. Including anti-haystack in the training mix is maybe a little questionable, since it would improve any model's performance on a task like this, so I'm running another test that applies the same SFT to the stock GDN model to compare results. I will update when it's done.
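
A gates-only run boils down to choosing which parameters stay trainable and freezing the rest. This is a minimal sketch of that selection step using only the standard library; the parameter names and the regex are hypothetical placeholders, not the names used in the project's GDN-2 module.

```python
import re

# Hypothetical parameter names; the real GDN-2 module may name its gates differently.
param_names = [
    "model.embed_tokens.weight",
    "model.layers.0.linear_attn.in_proj_qkv.weight",
    "model.layers.0.linear_attn.erase_gate.weight",
    "model.layers.0.linear_attn.write_gate.weight",
    "model.layers.0.linear_attn.decay_gate.weight",
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.3.self_attn.q_proj.weight",
]

# Match only the (hypothetical) GDN-2 gate projections inside linear-attention layers.
GATE_PATTERN = re.compile(r"linear_attn\.(erase_gate|write_gate|decay_gate)\.")


def trainable_names(names, pattern=GATE_PATTERN):
    """Return the parameter names to leave trainable for gates-only SFT."""
    return [n for n in names if pattern.search(n)]


trainable = trainable_names(param_names)
print(trainable)
```

In PyTorch, you would then call `p.requires_grad_(name in keep)` for each `name, p` in `model.named_parameters()`, with `keep = set(trainable)`, and pass only the parameters that still require gradients to the optimizer.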

For now, the results seem pretty promising for an in-place GDN-2 swap. On niah_multivalue at 131k tokens:

Model                          | Score
Base GDN-2 swap                | 92.5
Gate-only SFT (checkpoint-101) | 92.19
Full SFT (checkpoint-51)       | 98.25

That's +5.75 points on needle-in-a-haystack over the base swap (about 6.2% relative)! We will see how the GDN baseline compares.
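
The scores in the table are on a 0 to 100 scale, so it helps to separate the absolute change (in points) from the relative change (in percent). A small sketch using the table's numbers; `changes_vs_base` is a hypothetical helper written for this illustration.

```python
def changes_vs_base(scores, base_name):
    """Absolute (points) and relative (%) change of each score vs. a baseline."""
    base = scores[base_name]
    return {
        name: (score - base, 100.0 * (score - base) / base)
        for name, score in scores.items()
    }


scores = {
    "Base GDN-2 swap": 92.5,
    "Gate-only SFT (checkpoint-101)": 92.19,
    "Full SFT (checkpoint-51)": 98.25,
}
changes = changes_vs_base(scores, "Base GDN-2 swap")
for name, (pts, pct) in changes.items():
    print(f"{name:32s} {pts:+.2f} pts  {pct:+.2f}%")
```

Points are the natural unit for comparing benchmark accuracies; the relative figure only matters when scores are far from the ceiling.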
