Overcoming HBM-VMEM Bottlenecks in TPU-v3 Recurrent Workloads

by Gating Technology, May 6th, 2025

Too Long; Didn't Read

Novel recurrence gates and complex-valued units boost stability and efficiency in linear recurrent models, optimized for TPU-v3 hardware.

Authors:

(1) Soham De, Google DeepMind (equal contribution);

(2) Samuel L. Smith, Google DeepMind (equal contribution);

(3) Anushan Fernando, Google DeepMind (equal contribution);

(4) Aleksandar Botev, Google DeepMind (equal contribution);

(5) George Cristian-Muraru, Google DeepMind (equal contribution);

(6) Albert Gu, Work done while at Google DeepMind;

(7) Ruba Haroun, Google DeepMind;

(8) Leonard Berrada, Google DeepMind;

(9) Yutian Chen, Google DeepMind;

(10) Srivatsan Srinivasan, Google DeepMind;

(11) Guillaume Desjardins, Google DeepMind;

(12) Arnaud Doucet, Google DeepMind;

(13) David Budden, Google DeepMind;

(14) Yee Whye Teh, Google DeepMind;

(15) Razvan Pascanu, Google DeepMind;

(16) Nando De Freitas, Google DeepMind;

(17) Caglar Gulcehre, Google DeepMind.

1 Introduction

2 Model Architecture

3 Recurrent Models Scale as Efficiently as Transformers

3.1. Scaling curves

3.2. Evaluation on downstream tasks

4 Training Recurrent Models Efficiently on Device and 4.1. Model parallelism for large scale training

4.2. Efficient linear recurrences on device

4.3. Training speed on longer sequences

5. Inference Speed

5.1. A simple model of the decode step

5.2. Results

6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts

6.2. Copy and retrieval capabilities

7. Related Works

8. Conclusion, Acknowledgements, and References


A. RG-LRU Recurrence Gate

B. Complex-Gated Linear Recurrent Unit (CG-LRU)

C. Model Scale Hyper-Parameters

D. Efficient Linear Recurrences on Device

E. The Local Attention Window Size of Griffin

F. Inference Speeds

G. Improving Next Token Prediction with Longer Contexts: Additional Results

H. Additional Details of the Copy and Retrieval Tasks

A. RG-LRU Recurrence Gate

In Figure 7, we demonstrate the behavior of different gating mechanisms applied on the recurrent weight a.


Figure 7 | The behavior of different gating mechanisms applied on the recurrent weight 𝑎 (note that in Mamba's notation this is −𝐴).


Implementation. We implement our recurrence gate, as defined in Section 2.4, in a slightly different but mathematically equivalent form for numerical stability. In particular, we compute the logarithm of 𝑎𝑡 and then exponentiate it, instead of computing a sigmoid and then taking a power.
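The log-space form can be sketched as follows. This is a minimal illustration, not the authors' code: it assumes the gate has the form a_t = a^(c * r_t) with a = sigmoid(Λ), as defined in Section 2.4 of the paper (not reproduced in this appendix), and relies on the identity log sigmoid(Λ) = -softplus(-Λ). The function names and the default value c = 8 are assumptions made for the example.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def gate_naive(lam: float, r_t: float, c: float = 8.0) -> float:
    """Sigmoid first, then a power: a_t = sigmoid(lam) ** (c * r_t)."""
    return sigmoid(lam) ** (c * r_t)


def gate_log_space(lam: float, r_t: float, c: float = 8.0) -> float:
    """Log first, then exponentiate: log a_t = -c * r_t * softplus(-lam)."""
    log_a_t = -c * r_t * softplus(-lam)
    return math.exp(log_a_t)
```

For moderate inputs the two functions agree to floating-point precision. The log-space version never forms sigmoid(Λ) explicitly, and it makes log a_t available directly for any later step that needs it.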



B. Complex-Gated Linear Recurrent Unit (CG-LRU)




We split the input vector in two and interpret its first half as the real part of a complex vector, and its second half as the imaginary part of the same complex vector.
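The real-to-complex reinterpretation described here can be sketched in a few lines. This is an illustrative sketch only: the helper names `to_complex` and `to_real` are hypothetical, and the assumption that the split is taken along an even-length channel dimension is ours.

```python
def to_complex(x: list[float]) -> list[complex]:
    """Treat the first half of x as real parts and the second half as imaginary parts."""
    if len(x) % 2 != 0:
        raise ValueError("input length must be even to split into two halves")
    half = len(x) // 2
    return [complex(re, im) for re, im in zip(x[:half], x[half:])]


def to_real(z: list[complex]) -> list[float]:
    """Inverse mapping: stack real parts followed by imaginary parts."""
    return [v.real for v in z] + [v.imag for v in z]
```

Note that the complex vector has half as many entries as the real input, so `to_real(to_complex(x))` recovers a vector of the original length.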





With this we rewrite the equations for the LRU (see eq. 4) as:
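The rewritten equations did not survive in this copy. As a hedged sketch of what a single complex-valued update could look like, the code below assumes a complex diagonal recurrence weight with magnitude sigmoid(Λ) and a phase θ, raised to the gated power c * r_t, and an update of the form h_t = a_t * h_{t-1} + sqrt(1 - |a_t|^2) * (i_t * x_t). The parameter θ, the default c = 8, and all function and argument names are assumptions for illustration, not the paper's stated definitions.

```python
import cmath
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def cg_lru_step(h_prev, x_tilde, r, i_gate, lam, theta, c=8.0):
    """One element-wise step of an assumed complex gated linear recurrence.

    h_prev, x_tilde: lists of complex numbers (previous state, complex input).
    r, i_gate:       lists of real gate values in [0, 1].
    lam, theta:      lists of real parameters (log-magnitude logit and phase).
    """
    h_next = []
    for h, x, r_t, i_t, lam_k, theta_k in zip(h_prev, x_tilde, r, i_gate, lam, theta):
        # Log-magnitude of the gated weight: c * r_t * log(sigmoid(lam_k)).
        log_mag = -c * r_t * softplus(-lam_k)
        # The phase is scaled by the same gated exponent.
        phase = c * r_t * theta_k
        a_t = cmath.exp(complex(log_mag, phase))
        # Input normalisation sqrt(1 - |a_t|^2), clamped at zero for safety.
        scale = math.sqrt(max(0.0, 1.0 - math.exp(2.0 * log_mag)))
        h_next.append(a_t * h + scale * (i_t * x))
    return h_next
```

With a recurrence gate of zero the weight is exactly one and the state is carried over unchanged; as the gate opens, the magnitude of the weight shrinks and more of the gated input is mixed in.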




C. Model Scale Hyper-Parameters

In Table 2, we present the hyperparameters of the models at different scales. These hyperparameters are shared across all the model families that we explored in this paper.



Table 2 | Key model hyperparameters considered for different model sizes. These hyperparameters are shared across the different architectures we tested.


D. Efficient Linear Recurrences on Device

The first step in optimizing a computation is to identify the primary performance bottleneck on the target hardware. For most accelerators, the key limiting factors are computational throughput (FLOPs/s) and memory bandwidth between the high-bandwidth memory (HBM) and the fast vector memory (VMEM). Other factors, such as HBM capacity and host-device communication, also matter, but techniques such as ZeRO sharding and pipelined data transfer offer practical mitigations. Modern accelerator designs often prioritize a high FLOPs-to-byte ratio to accommodate workloads where computations significantly outnumber memory transfers. Table 3 shows the key specifications of the TPU-v3 pod (two chips per pod), which we use for all our experiments.
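To make the FLOPs-to-byte reasoning concrete, the sketch below estimates the arithmetic intensity of an element-wise linear recurrence h_t = a_t * h_{t-1} + x_t and compares it with a machine's FLOPs-to-byte ratio. It is a back-of-the-envelope model under stated assumptions: two FLOPs per element, three values moved between HBM and VMEM per element (read a_t, read x_t, write h_t), the running state kept on chip, and two bytes per value. The function names are hypothetical, and the hardware numbers in the usage note are placeholders, not TPU-v3 figures.

```python
def linear_recurrence_cost(seq_len: int, width: int, bytes_per_value: int = 2):
    """Return (flops, bytes_moved) for h_t = a_t * h_{t-1} + x_t over a full sequence."""
    n = seq_len * width
    flops = 2 * n                        # one multiply and one add per element
    bytes_moved = 3 * n * bytes_per_value  # read a_t and x_t, write h_t
    return flops, bytes_moved


def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte transferred."""
    return flops / bytes_moved


def is_memory_bound(flops, bytes_moved, peak_flops_per_s, bandwidth_bytes_per_s) -> bool:
    """True when the workload's intensity is below the machine's FLOPs-to-byte ratio."""
    machine_balance = peak_flops_per_s / bandwidth_bytes_per_s
    return arithmetic_intensity(flops, bytes_moved) < machine_balance
```

For example, `linear_recurrence_cost(2048, 1024)` gives an intensity of about 0.33 FLOPs per byte. On a hypothetical device with 100e12 FLOPs/s and 1e12 bytes/s of memory bandwidth (a ratio of 100), `is_memory_bound` returns True, which is the regime where reducing memory transfers matters more than reducing FLOPs.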



Table 3 | Hardware specifications for a TPU-v3 pod.




Figure 8 | a) Runtimes of different implementations of the scan operation on a TPU-v3 at different sequence lengths. The batch size of the input is fixed at 8 and the dimension of each token is 1024. b) Relative runtimes of the Hawk model when using different imp
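For readers unfamiliar with the scan operation, the sketch below shows two common ways to evaluate the linear recurrence h_t = a_t * h_{t-1} + x_t with h_0 = 0: a sequential loop and a parallel-prefix formulation built on an associative combine. Which implementations the figure actually compares is not recoverable from this copy, so this is only an illustration of the general technique in plain Python; the function names are hypothetical.

```python
def linear_scan(a, x):
    """Sequential evaluation: one dependent step per time index."""
    h, out = 0.0, []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return out


def combine(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_l * a_r, a_r * b_l + b_r)


def associative_scan(a, x):
    """Inclusive prefix scan with log2(n) rounds (Hillis-Steele style).

    Within each round every update is independent, so a parallel device
    could perform them simultaneously; here they run in a plain loop.
    """
    elems = list(zip(a, x))
    n = len(elems)
    step = 1
    while step < n:
        updated = list(elems)
        for i in range(step, n):
            updated[i] = combine(elems[i - step], elems[i])
        elems = updated
        step *= 2
    # With h_0 = 0, the state at time t is the offset of the composed map.
    return [b for _, b in elems]
```

Both functions return the same sequence of states. The sequential version does the minimum amount of arithmetic but has a dependency chain as long as the sequence, while the associative version trades extra arithmetic and memory traffic for a shorter chain of dependent rounds.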


D.2. Scan runtimes



This paper is available on arXiv under the CC BY 4.0 DEED license.

