Optimizing Language Models: Decoding Griffin’s Local Attention and Memory Efficiency

by Gating Technology, May 6th, 2025

Too Long; Didn't Read

Explores how Griffin’s local attention and recurrent layers outperform traditional Transformers, improving language modeling at scale and enabling faster inference.

Authors:

(1) Soham De, Google DeepMind, with equal contributions;

(2) Samuel L. Smith, Google DeepMind, with equal contributions;

(3) Anushan Fernando, Google DeepMind, with equal contributions;

(4) Aleksandar Botev, Google DeepMind, with equal contributions;

(5) George Cristian-Muraru, Google DeepMind, with equal contributions;

(6) Albert Gu, Work done while at Google DeepMind;

(7) Ruba Haroun, Google DeepMind;

(8) Leonard Berrada, Google DeepMind;

(9) Yutian Chen, Google DeepMind;

(10) Srivatsan Srinivasan, Google DeepMind;

(11) Guillaume Desjardins, Google DeepMind;

(12) Arnaud Doucet, Google DeepMind;

(13) David Budden, Google DeepMind;

(14) Yee Whye Teh, Google DeepMind;

(15) Razvan Pascanu, Google DeepMind;

(16) Nando De Freitas, Google DeepMind;

(17) Caglar Gulcehre, Google DeepMind.

1 Introduction

2 Model Architecture

3 Recurrent Models Scale as Efficiently as Transformers

3.1. Scaling curves

3.2. Evaluation on downstream tasks

4 Training Recurrent Models Efficiently on Device and 4.1. Model parallelism for large scale training

4.2. Efficient linear recurrences on device

4.3. Training speed on longer sequences

5. Inference Speed

5.1. A simple model of the decode step

5.2. Results

6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts

6.2. Copy and retrieval capabilities

7. Related Works

8. Conclusion, Acknowledgements, and References


A. RG-LRU Recurrence Gate

B. Complex-Gated Linear Recurrent Unit (CG-LRU)

C. Model Scale Hyper-Parameters

D. Efficient Linear Recurrences on Device

E. The Local Attention Window Size of Griffin

F. Inference Speeds

G. Improving Next Token Prediction with Longer Contexts: Additional Results

H. Additional Details of the Copy and Retrieval Tasks

E. The Local Attention Window Size of Griffin

Griffin uses both recurrent blocks and local attention layers in its temporal mixing blocks. For all experiments shown previously, which used a training sequence length of 2048, we used a local attention window size of 1024. We now investigate how the performance of different local attention window sizes varies with the training sequence length.


We consider 400M parameter models trained on sequence lengths of 2048, 4096 and 8192 tokens, where we keep the total number of training tokens fixed. For each sequence length, we train Griffin models using different local attention window sizes. As baselines, we train MQA Transformers using global attention layers, as well as MQA Transformers using local attention layers with different window sizes. The results are shown in Figure 9, where the window sizes used are shown on top of each bar (MQA Transformer bars with a window size equal to the training sequence length correspond to the global attention MQA Transformer baseline).


Figure 9 | Performance of 400M parameter Griffin and MQA Transformer models using different local attention window sizes and training sequence lengths. The window size of each local attention layer is shown above its bar in the plot. The global attention MQA Transformer is much better than local attention variants of the MQA Transformer (where the window size is smaller than the training sequence length). Furthermore, using a fixed local attention window size of 1024 (denoted ‘1K’ in the plot) for the Griffin model outperforms all global attention and local attention MQA Transformer baselines across all training sequence lengths.


From Figure 9, we see that, remarkably, even with a fixed window size of 1024 for its local attention layers, Griffin outperforms the global attention MQA Transformer baseline across all sequence lengths tested. However, it is worth noting that the performance gap between Griffin with local attention window 1024 and the global attention MQA Transformer narrows as the sequence length grows. Therefore, if the sequence length grows further, it is likely important to gradually grow the local attention window size as well. In practice, the hardware used will also heavily influence the optimal local attention window size in terms of training and inference speed. Finally, we note that MQA Transformers using purely local attention (with window sizes smaller than the training sequence length) perform significantly worse than both global attention MQA Transformers and Griffin.
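To make concrete what the local attention window controls, the short NumPy sketch below builds the causal sliding-window mask that a local attention layer with a given window size would apply. It is an illustrative sketch rather than the paper's implementation, and the sequence length and window size used here are arbitrary.

import numpy as np

def local_attention_mask(seq_len, window):
    # Position i may attend to positions max(0, i - window + 1) .. i,
    # i.e. to itself and to at most window - 1 preceding tokens.
    i = np.arange(seq_len)[:, None]   # query positions
    j = np.arange(seq_len)[None, :]   # key positions
    return (j <= i) & (j > i - window)

# With a training sequence length of 2048 and a window of 1024, each token
# attends to at most the 1024 most recent tokens (itself included), so the
# attention cost and cache grow with the window rather than the sequence.
print(local_attention_mask(seq_len=8, window=3).astype(int))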


F. Inference Speeds

F.1. Estimating memory-boundedness

The inference speed of language models at decode time is bounded by memory loading. As described in Section 4.2, the linear RNN is memory bound. In the following, we show that the same is true for the other components of our recurrent models and of Transformer models, namely the linear layers and self-attention.

F.2. Estimating the memory boundedness of linear layers

As shown in Appendix D.1, the outer dimension (usually the product of the batch size 𝐵 and the sequence length 𝑇) must be at least 136 for a linear layer to be compute bound. At decode time 𝑇 = 1, so if we assume 𝐵 ≲ 128, all linear layers are memory bound at decode time.
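As a concrete illustration of this argument, the back-of-the-envelope Python sketch below counts the FLOPs and the bytes moved for a single bf16 linear layer during decoding. The model width of 4096 and the batch size of 128 are assumed values chosen for illustration, not hyper-parameters from the paper.

def linear_arithmetic_intensity(outer, d_in, d_out, bytes_per_elem=2):
    # FLOPs per byte moved for a (outer, d_in) @ (d_in, d_out) matmul in bf16.
    flops = 2 * outer * d_in * d_out                  # one multiply-add per pair
    bytes_moved = bytes_per_elem * (outer * d_in      # load the activations
                                    + d_in * d_out    # load the weight matrix
                                    + outer * d_out)  # store the outputs
    return flops / bytes_moved

B, T, D = 128, 1, 4096   # decode step: one token per sequence; assumed model width
intensity = linear_arithmetic_intensity(B * T, D, D)
# Roughly 120 FLOPs/byte here, below the ~136 threshold quoted above,
# so the layer is memory bound at decode time.
print(f"arithmetic intensity at decode time: {intensity:.0f} FLOPs/byte")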

F.3. Estimating the memory boundedness of self-attention

In the following, we calculate the ratio of memory accesses to arithmetic operations for the attention computation at the 𝐿-th decode step, showing that it is also memory bound.


To simplify the following analysis, we assume that we start from an empty prompt (or equivalently assume that the prefill contains 0 tokens).
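To give a flavour of this calculation, the rough Python sketch below counts the FLOPs and the bytes moved for a single attention decode step with a cache of 𝐿 tokens. The head count, head dimension and bf16 byte counts are illustrative assumptions rather than values from the paper, and the softmax cost is ignored.

def attention_decode_intensity(L, n_heads, head_dim, kv_heads, bytes_per_elem=2):
    # FLOPs per byte for one decode step; kv_heads=1 corresponds to MQA,
    # kv_heads=n_heads to standard MHA.
    flops = 4 * n_heads * L * head_dim                            # q @ K^T plus weights @ V
    bytes_moved = bytes_per_elem * 2 * kv_heads * L * head_dim    # load the K and V cache
    return flops / bytes_moved

L, n_heads, head_dim = 2048, 16, 128   # assumed sizes
mqa = attention_decode_intensity(L, n_heads, head_dim, kv_heads=1)
mha = attention_decode_intensity(L, n_heads, head_dim, kv_heads=n_heads)
# Both ratios are independent of L and far below the ~136 FLOPs/byte needed
# to be compute bound, so the attention decode step is memory bound.
print(f"MQA: {mqa:.0f} FLOPs/byte, MHA: {mha:.0f} FLOPs/byte")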




F.4. Cache sizes

In the following, we analyze the relative sizes of the caches used in our recurrent models and in Transformers. All cache sizes scale linearly with the batch size, and in the following we assume 𝐵 = 1.


F.4.1. The size of the KV cache





For either MHA or MQA, the size of the KV cache can exceed the number of model parameters when the sequence length 𝑇 is large. We therefore expect to observe a transition from a ‘parameter bound’ regime at short sequence lengths, during which the decoding speed is dominated by the time taken to load the model parameters onto the device, to a ‘cache bound’ regime at long sequence lengths, where the decoding speed is dominated by the time taken to load the KV cache.
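The rough Python sketch below illustrates this transition by comparing the number of cached key/value elements with the number of model parameters as the sequence length grows. The layer count, head count, head dimension and parameter count are assumed, illustrative values rather than the paper's configurations.

def kv_cache_elements(seq_len, n_layers, kv_heads, head_dim):
    # Number of cached key and value elements for a single sequence (B = 1).
    return 2 * n_layers * seq_len * kv_heads * head_dim   # factor 2: keys + values

n_layers, n_heads, head_dim = 24, 16, 128   # assumed architecture
n_params = 1.0e9                            # assumed parameter count

for T in (2_048, 8_192, 32_768, 131_072):
    mha = kv_cache_elements(T, n_layers, kv_heads=n_heads, head_dim=head_dim)
    mqa = kv_cache_elements(T, n_layers, kv_heads=1, head_dim=head_dim)
    # Once the cache holds more elements than the model has parameters,
    # loading the cache rather than the weights dominates decoding time.
    print(f"T={T:>7}: MHA cache/params = {mha / n_params:.2f}, "
          f"MQA cache/params = {mqa / n_params:.3f}")

Under these illustrative values, the MHA cache overtakes the parameter count at a sequence length of roughly ten thousand tokens, while the MQA cache remains far smaller.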


F.4.2. The size of the recurrent state





F.4.3. The local attention cache




G. Improving Next Token Prediction with Longer Contexts: Additional Results

Figure 10 shows an additional result demonstrating next token prediction performance at different context lengths on a held-out dataset of arXiv articles. We find that the results on this dataset are qualitatively similar to the results shown in Figure 5.



Figure 10 | The evaluation performance of 1B parameter models across a range of sequence lengths on held-out evaluation sets of ArXiv articles. On the left, we compare the performance of different models trained with sequence length 2048, evaluated with a sequence length of up to 32,768. On the right, we compare Griffin and Hawk when trained respectively on 2048 (2k) and 8192 (8k) sequence lengths. Results are qualitatively similar to the evaluation on Books presented in Figure 5.


H. Additional Details of the Copy and Retrieval Tasks

Figure 11 is an illustration of the Selective Copying and Induction Heads tasks.


In the Selective Copying task, the model needs to learn to copy data tokens (coloured tokens in Figure 11) from a sequence while ignoring noise tokens (white tokens in Figure 11). Crossed-out tokens in the output in Figure 11 denote tokens that are masked out in the loss.



Figure 11 | An illustration of the Selective Copying (left) and the Induction Heads tasks (right).



In the Induction Heads task, the model needs to learn to recall the token immediately following a special token (black token in Figure 11). As before, crossed out tokens in the output denote tokens that are masked out in the loss.
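As a concrete illustration of the two task formats described above, the short Python sketch below generates toy examples of each. The vocabulary, sequence lengths and the use of explicit 'NOISE' and 'SPECIAL' marker tokens are assumptions made for readability, not the paper's exact data-generation procedure.

import random

def selective_copying_example(n_data=4, n_noise=12, vocab=tuple("ABCDEFGH")):
    # Data tokens are scattered among noise tokens; the target is the data
    # tokens in their order of appearance, and the loss is applied only there.
    seq = [random.choice(vocab) for _ in range(n_data)] + ["NOISE"] * n_noise
    random.shuffle(seq)
    target = [tok for tok in seq if tok != "NOISE"]
    return seq, target

def induction_heads_example(length=16, vocab=tuple("ABCDEFGH")):
    # The model must recall the token that immediately followed the special
    # token when it is queried with the special token again at the end.
    seq = [random.choice(vocab) for _ in range(length)]
    pos = random.randrange(length - 1)
    seq[pos] = "SPECIAL"
    answer = seq[pos + 1]
    return seq + ["SPECIAL"], answer

print(selective_copying_example())
print(induction_heads_example())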

This paper is available on arXiv under the CC BY 4.0 DEED license.

