arXiv 2505.01618

Don't be lazy: CompleteP enables compute-efficient deep transformers

Published:  at  11:16 AM

This paper introduces CompleteP, a transformer parameterization that sets the depth-scaling exponent α of the residual branches to 1. CompleteP ensures depth-wise hyperparameter transfer and complete feature learning, achieves 12-34% compute efficiency improvements, and enables a wider range of compute-optimal width-to-depth ratios.

Large Language Model, Transformer, Pre-training, Efficiency, Scaling Laws

Nolan Dey, Bin Claire Zhang, Lorenzo Noci, Mufan Li, Blake Bordelon, Shane Bergsma, Cengiz Pehlevan, Boris Hanin, Joel Hestness

Cerebras Systems, ETH Zurich, Princeton University, Harvard University

Generated by grok-3

Background Problem

The research addresses the challenge of compute efficiency in training large language models (LLMs), where scaling model size (width and depth) often requires costly re-tuning of hyperparameters (HPs) like learning rates, leading to suboptimal resource use. Existing parameterizations, such as the maximal update parameterization (µP), achieve HP transfer across width but struggle with depth, often resulting in ‘lazy learning’ where deeper layers fail to learn non-linear features effectively. This work aims to solve the problem of achieving stable HP transfer across both depth and width while ensuring non-lazy, effective feature learning in all layers, ultimately improving compute efficiency and model performance.

Method

The core method introduced is CompleteP, a parameterization for transformer models defined by setting the depth-dependent re-scaling exponent α to 1 in the residual block update $\mathbf{h}^{\ell+1} = \mathbf{h}^{\ell} + L^{-\alpha} \, \mathcal{F}_{\ell}(\mathbf{h}^{\ell})$, which for α = 1 becomes $\mathbf{h}^{\ell+1} = \mathbf{h}^{\ell} + L^{-1} \, \mathcal{F}_{\ell}(\mathbf{h}^{\ell})$, where L is the depth and $\mathcal{F}_{\ell}$ is the ℓ-th residual block (MLP or attention). The method extends prior work by also re-scaling the LayerNorm and bias learning rates, AdamW's weight decay (λ), and AdamW's ϵ as functions of depth and width. These re-scalings are designed so that optimal HPs stay stable as model size grows, avoiding re-tuning, and so that every layer keeps learning non-linear features (complete feature learning). In practice, this means adjusting initialization variances, learning rates, and optimizer parameters according to the width multiplier ($m_N$) and depth multiplier ($m_L$), as detailed in the paper's Table 1, to keep training dynamics consistent across scales.
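To make the residual update concrete, here is a minimal Python sketch, assuming a toy residual stack in which each block F_ℓ is a plain function on a list of floats. The names `residual_forward`, `tanh_block`, and `constant_block` are hypothetical, and LayerNorm, attention, and the Table 1 HP re-scalings are deliberately left out; the sketch only applies the L^{-α} multiplier so that α = 1 can be compared with α = 0.5.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Block = Callable[[Vector], Vector]


def residual_forward(h: Vector, blocks: Sequence[Block], alpha: float = 1.0) -> Vector:
    """Apply h^{l+1} = h^l + L^{-alpha} * F_l(h^l) for every block in order.

    alpha = 1.0 gives the CompleteP residual scaling (L^{-1});
    alpha = 0.5 gives the L^{-1/2} alternative used as a baseline.
    """
    depth = len(blocks)                # L, the number of residual blocks
    scale = depth ** (-alpha)          # residual-branch multiplier L^{-alpha}
    for block in blocks:
        update = block(h)              # F_l(h^l)
        h = [x + scale * u for x, u in zip(h, update)]
    return h


def tanh_block(weight: float) -> Block:
    """Hypothetical non-linear stand-in for an MLP/attention block."""
    return lambda h: [math.tanh(weight * x) for x in h]


def constant_block(h: Vector) -> Vector:
    """Hypothetical block that always outputs 1.0, to isolate the depth scaling."""
    return [1.0 for _ in h]


if __name__ == "__main__":
    for depth in (4, 16, 64):
        blocks = [constant_block] * depth
        # alpha = 1: the summed residual updates stay at about 1.0 at every depth.
        # alpha = 0.5: they grow like sqrt(depth).
        print(depth,
              residual_forward([0.0], blocks, alpha=1.0),
              residual_forward([0.0], blocks, alpha=0.5))
    # A non-linear toy stack, just to show the same function with real blocks.
    print(residual_forward([0.5, -0.5], [tanh_block(0.8)] * 8, alpha=1.0))
```

The constant block is deliberately trivial: it only shows how the L^{-α} multiplier controls the summed contribution of the residual branches as depth grows. It says nothing on its own about lazy versus complete feature learning, which the paper analyzes through the full parameterization.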

Experiment

The experiments used decoder-only pre-LN transformers trained on the SlimPajama dataset, with model sizes from 75M to 1.9B parameters and a compute-optimal budget of 20 tokens per parameter (TPP). Width (N) and depth (L) were varied to study compute-optimal N:L ratios, with batch sizes set by FLOP-based power laws and well-tuned HPs. CompleteP (α = 1) transferred HPs across depth more stably than standard parameterization (SP), µP, and α = 0.5, and achieved 11.8% to 34.4% FLOP savings over µP, especially in deeper models. It also made a wider range of N:L ratios compute-efficient (down to N:L ≈ 10), in contrast to prior findings that favored N:L ≈ 100. Downstream zero-shot evaluations on tasks such as HellaSwag and LAMBADA confirmed that the upstream gains carried over to downstream performance.

The experimental design is comprehensive for the tested scales and settings and is consistent with compute-optimal expectations from prior literature, although the results may be specific to pre-LN architectures and the chosen dataset. The FLOP savings and stronger depth-wise performance are consistent with CompleteP avoiding lazy learning, but testing on a broader range of architectures and datasets is needed to establish robustness.
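The compute-optimal setup and the headline numbers can be made concrete with some back-of-the-envelope arithmetic. The Python sketch below assumes the common approximation of about 6 · P · D training FLOPs for P parameters and D tokens, and assumes "FLOP savings" means 1 − C_CompleteP / C_µP at matched loss; neither assumption is taken from the paper. The helper names and the example widths and depths are hypothetical.

```python
from dataclasses import dataclass

TOKENS_PER_PARAM = 20  # the 20 TPP compute-optimal setup described above


@dataclass
class RunBudget:
    params: int       # P, total parameter count
    tokens: int       # D, training tokens
    train_flops: int  # approximate training FLOPs


def compute_optimal_budget(params: int, tpp: int = TOKENS_PER_PARAM) -> RunBudget:
    """Tokens at a fixed tokens-per-parameter ratio, with FLOPs ~= 6 * P * D.

    6 * P * D is a common rule of thumb assumed here; the paper may count
    FLOPs differently (e.g. including attention terms).
    """
    tokens = params * tpp
    return RunBudget(params=params, tokens=tokens, train_flops=6 * params * tokens)


def flop_fraction(savings: float) -> float:
    """Share of the baseline's FLOPs still needed, assuming savings = 1 - C_new / C_base."""
    return 1.0 - savings


def width_depth_ratio(width: int, depth: int) -> float:
    """N:L, the model width divided by the number of layers."""
    return width / depth


if __name__ == "__main__":
    for p in (75_000_000, 1_900_000_000):
        b = compute_optimal_budget(p)
        print(f"P={b.params:.2e}  D={b.tokens:.2e}  FLOPs~{b.train_flops:.2e}")
    for s in (0.118, 0.344):
        print(f"{s:.1%} savings -> {flop_fraction(s):.1%} of muP FLOPs")
    # Hypothetical shapes: one near the older N:L ~ 100 regime, one near N:L ~ 10.
    print(width_depth_ratio(2048, 20), width_depth_ratio(512, 48))
```

Under the 6 · P · D assumption, this gives about 1.5B tokens and 6.75 × 10^17 FLOPs for the 75M model, and 38B tokens and about 4.3 × 10^20 FLOPs for the 1.9B model. A 34.4% saving would then mean reaching the same loss with roughly 65.6% of µP's FLOPs.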

Further Thoughts

The concept of complete feature learning introduced by CompleteP opens up intriguing avenues for further exploration, particularly in how it might interact with other scaling strategies like Mixture of Experts (MoE) or parallel sub-networks, which the authors deemed out of scope. Could CompleteP's depth-wise efficiency be combined with MoE's width-wise scaling to push compute efficiency even further, especially for ultra-large models? Additionally, the hardware implications of favoring narrow-deep models with CompleteP are significant. Such models can suit low-memory settings via weight streaming, but because their layers must be computed sequentially, they may increase per-token latency in inference-heavy applications. This trade-off warrants deeper investigation, perhaps in collaboration with hardware optimization studies. I'm also curious about the applicability of CompleteP to emerging architectures like state space models, which are gaining traction for their efficiency in handling long contexts. If CompleteP's principles of non-lazy learning can be adapted to such models, they could inform scaling strategies well beyond transformers. Lastly, the theoretical desiderata for HP transfer, especially complete feature learning, could inspire new metrics for evaluating model training dynamics, potentially bridging gaps between theoretical and empirical scaling-law research.


