This paper introduces a Rao-Blackwellized Monte Carlo estimator for KL divergence between language models, achieving unbiased estimates with provably lower variance than standard Monte Carlo methods, and demonstrates improved stability and performance in RLHF fine-tuning for sentiment-controlled generation.
Large Language Model, Reinforcement Learning, RLHF, Efficiency, Robustness
Afra Amini, Tim Vieira, Ryan Cotterell
ETH Zurich
Generated by grok-3
Background Problem
The paper addresses the challenge of estimating the Kullback-Leibler (KL) divergence between two language models, a task central to natural language processing with applications in reinforcement learning from human feedback (RLHF), interpretability, evaluation metrics, and knowledge distillation. Exact computation of the KL divergence is intractable for neural language models because it requires summing over the countably infinite set of possible strings, so practitioners rely on sampling-based estimators such as Monte Carlo (MC), which suffer from high variance and can even produce negative estimates despite the KL divergence being non-negative. The key problem solved is reducing the variance of KL divergence estimation while maintaining unbiasedness, thereby improving stability in applications such as RLHF, where KL serves as a regularization term that prevents reward over-optimization and preserves model fluency.
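For reference, the quantity being estimated and its standard MC estimate can be written as follows; the notation ($p$, $q$, $\boldsymbol{y}$, $M$) is chosen here for illustration and may differ from the paper's own symbols.

```latex
% Illustrative notation: p and q are the two language models,
% y ranges over strings, and M is the number of samples drawn from p.
\mathrm{KL}(p \,\|\, q) \;=\; \sum_{\boldsymbol{y}} p(\boldsymbol{y}) \log \frac{p(\boldsymbol{y})}{q(\boldsymbol{y})},
\qquad
\widehat{\mathrm{KL}}_{\mathrm{MC}} \;=\; \frac{1}{M} \sum_{m=1}^{M} \log \frac{p(\boldsymbol{y}^{(m)})}{q(\boldsymbol{y}^{(m)})},
\qquad \boldsymbol{y}^{(m)} \sim p.
```

Because individual log-ratio terms can be negative, the MC average can itself come out negative even though the true KL cannot, which is precisely the instability the paper targets.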
Method
The paper proposes a Rao-Blackwellized (RB) Monte Carlo estimator of the KL divergence between language models. The core idea is to apply Rao-Blackwellization, a classical variance reduction technique from statistics, by conditioning the estimator on the prefix of the sampled string at each token position and computing the exact KL divergence between the models' conditional next-token distributions at that step. The implementation involves: (1) sampling strings from the first language model; (2) at each token position, computing the exact KL divergence between the two conditional next-token distributions given the sampled prefix, which requires a sum over the vocabulary; and (3) summing these step-wise terms to obtain the final KL estimate. The method is unbiased and its variance is guaranteed to be less than or equal to that of the standard MC estimator, while retaining the same computational complexity, since the additional per-token calculations do not change the overall runtime order. The paper also derives a Rao-Blackwellized estimator for the gradient of the KL divergence, crucial for optimization in RLHF, using an analogous conditioning approach to reduce the variance of gradient estimates.
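To make the step-wise computation concrete, here is a minimal PyTorch sketch, assuming the full next-token log-probability matrices of both models along a sampled string are available; the names (`rb_kl_estimate`, `mc_kl_estimate`, `logprobs_p`, `logprobs_q`) are illustrative and not taken from the paper's implementation.

```python
import torch

def rb_kl_estimate(logprobs_p: torch.Tensor, logprobs_q: torch.Tensor) -> torch.Tensor:
    """Rao-Blackwellized KL estimate for a single string y sampled from the first model.

    logprobs_p, logprobs_q: (T, V) log next-token probabilities of the two models,
    evaluated at every position t along the sampled string (T positions, vocab size V).
    """
    p = logprobs_p.exp()                                    # p(. | y_<t), shape (T, V)
    # Exact per-step KL between the conditional next-token distributions:
    # KL(p(. | y_<t) || q(. | y_<t)) = sum_v p(v | y_<t) * (log p(v | y_<t) - log q(v | y_<t)).
    step_kl = (p * (logprobs_p - logprobs_q)).sum(dim=-1)   # shape (T,)
    return step_kl.sum()                                    # sum over token positions

def mc_kl_estimate(logprobs_p: torch.Tensor, logprobs_q: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Standard MC estimate for the same string: log p(y) - log q(y),
    using only the log-probabilities of the tokens actually sampled."""
    lp = logprobs_p.gather(1, tokens[:, None]).squeeze(1)   # (T,)
    lq = logprobs_q.gather(1, tokens[:, None]).squeeze(1)
    return (lp - lq).sum()
```

Averaging either function over M strings sampled from the first model yields an unbiased estimate of the KL; the RB average has variance no larger than the MC average, at the cost of touching the full vocabulary at each step, distributions that the forward pass already produces.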
Experiment
The experiments are conducted on a sentiment-controlled generation task: a GPT-2 model is fine-tuned on the IMDB dataset to generate positive movie reviews, with the GPT-2 model before fine-tuning serving as the reference. The setup evaluates four KL divergence estimators (Monte Carlo (MC), control variate (CV), Horvitz-Thompson (HT), and Rao-Blackwellized (RB)) by sampling 4000 responses for 512 prompts from the IMDB evaluation set and assessing bias, variance, and consistency across sample sizes M = 1, 5, 10. Results show that the RB estimator achieves the lowest standard deviation (e.g., 0.03 at M = 10 versus 0.05 for MC), confirming substantial variance reduction while remaining unbiased around the expected KL of 6.76. In RLHF training, using the RB estimator for gradient computation yields more stable training across multiple runs: RB-trained models appear on the Pareto frontier of reward versus KL 76% of the time (95% for KL < 5), compared to MC-trained models, indicating a better balance between reward maximization and the KL constraint. The experimental design is reasonable for the sentiment task but is limited to one model and dataset, raising questions about generalizability to other architectures or domains. The results match the expected variance reduction, though the practical computational overhead for larger vocabularies remains unaddressed.
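To illustrate the bias/variance evaluation protocol on a self-contained example (a toy Markov-chain "language model" rather than GPT-2, purely for illustration), the sketch below draws repeated batches of M samples, computes one MC and one RB estimate per batch, and compares their spread; all names and constants are invented for this example.

```python
import torch

torch.manual_seed(0)

# Toy stand-ins for the two language models: first-order Markov chains over a
# small vocabulary, so the per-step conditionals (and sampling) are cheap.
V, T, BOS = 20, 15, 0                               # vocab size, sequence length, start token
P = torch.randn(V, V).log_softmax(-1)               # log p(next | prev)
Q = (P + 0.3 * torch.randn(V, V)).log_softmax(-1)   # log q(next | prev), a perturbation of p

def estimate_kl(M: int):
    """One MC estimate and one RB estimate of KL(p || q), each from M sampled strings."""
    prev = torch.full((M,), BOS, dtype=torch.long)
    mc = torch.zeros(M)
    rb = torch.zeros(M)
    for _ in range(T):
        cond_p = P[prev].exp()                        # (M, V) conditionals p(. | prefix)
        tok = torch.multinomial(cond_p, 1).squeeze(1) # sample the next token from p
        mc += P[prev, tok] - Q[prev, tok]             # log-ratio of the sampled token only
        rb += (cond_p * (P[prev] - Q[prev])).sum(-1)  # exact per-step KL given the prefix
        prev = tok
    return mc.mean().item(), rb.mean().item()

for M in (1, 5, 10):
    runs = torch.tensor([estimate_kl(M) for _ in range(500)])
    mc_mean, rb_mean = runs.mean(dim=0).tolist()
    mc_std, rb_std = runs.std(dim=0).tolist()
    print(f"M={M:2d}  MC: {mc_mean:.2f} ± {mc_std:.3f}   RB: {rb_mean:.2f} ± {rb_std:.3f}")
```

In this toy, both estimators center on the same value (both are unbiased) while the RB spread is consistently smaller; the paper's experiment carries out the analogous comparison with GPT-2 on IMDB prompts.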
Further Thoughts
The Rao-Blackwellized estimator opens up intriguing possibilities beyond RLHF, such as knowledge distillation, where more precise measurement of the distributional difference between teacher and student models could improve compression techniques. I am curious about its applicability to multimodal models, where KL divergence might be used to align distributions across modalities (e.g., text and image), potentially reducing variance in cross-modal alignment tasks. Connecting this work to interpretability research, as hinted in the paper, suggests a pathway to better understand how prompts shift model distributions: could this estimator help detect and mitigate biases by providing more stable divergence measures in response to biased prompts? A potential limitation to explore is the scalability of exact per-token KL computations for very large vocabularies or in low-resource settings; hybrid approaches combining RB with approximation techniques could be investigated. Lastly, relating this to recent work on emergent abilities in large language models, I wonder whether stable KL estimation could provide insight into how fine-tuning affects emergent reasoning or planning capabilities by quantifying distributional shifts more reliably.