Continual Quantization-Aware Pre-Training: When to transition from 16-bit to 1.58-bit pre-training for BitNet language models?
Journal: arXiv
Published Date: Feb 17, 2025
Abstract
Large language models (LLMs) require immense resources for training and
inference. Quantization, a technique that reduces the precision of model
parameters, offers a promising solution for improving LLM efficiency and
sustainability. While post-training quantization methods typically achieve 4-8
bits per parameter, recent research suggests that training LLMs with 1.58 bits
per weight parameter from scratch can maintain model accuracy while greatly
reducing memory requirements and energy consumption at inference time. Here, we
investigate a training strategy for quantization-aware pre-training, where the
models are first trained with 16-bit precision and then transition into
1.58-bit quantization-aware training. Our results on 11 downstream tasks show
that this 16-to-1.58-bit training strategy is preferable to full 1.58-bit
training and leaves models closer to those that have undergone 16-bit
training. We further investigate the effects of retaining the optimizer state
at the transition point and of gradually phasing in quantization strength,
finding that both techniques reduce the magnitude of loss spikes, but also
that these effects can be compensated for through further training.
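To make "1.58 bits per weight" and "phasing in quantization strength" concrete, here is a minimal forward-pass sketch. It assumes absmean ternary quantization as described for BitNet b1.58 (weights divided by their mean absolute value, rounded, and clipped to {-1, 0, +1}), since a ternary weight needs log2(3), about 1.58, bits. The helper names `absmean_ternary`, `phased_weights`, and `quantization_strength`, and the linear ramp schedule, are hypothetical illustrations, not the method or code from this paper. Real quantization-aware training additionally keeps full-precision latent weights and uses a straight-through estimator for gradients, which this sketch omits.

```python
import math

# A ternary weight takes one of three values {-1, 0, +1};
# encoding three states needs log2(3) ≈ 1.585 bits.
BITS_PER_TERNARY_WEIGHT = math.log2(3)


def absmean_ternary(weights, eps=1e-5):
    """Quantize a flat list of weights to {-1, 0, +1} with one per-tensor scale.

    Assumption: absmean scaling as in BitNet b1.58; the paper summarized
    above may differ in details (e.g. granularity of the scale).
    """
    scale = sum(abs(w) for w in weights) / len(weights)
    scale = max(scale, eps)  # guard against an all-zero tensor
    # Round each scaled weight, then clip into the ternary range [-1, 1].
    ternary = [max(-1, min(1, round(w / scale))) for w in weights]
    # Dequantized values (ternary * scale) are what the forward pass would use.
    dequant = [t * scale for t in ternary]
    return dequant, ternary, scale


def phased_weights(weights, strength):
    """Hypothetical phase-in: interpolate between full-precision and quantized weights.

    strength = 0.0 keeps the full-precision weights; strength = 1.0 gives the
    fully ternary (dequantized) weights. Linear interpolation is an assumption.
    """
    dequant, _, _ = absmean_ternary(weights)
    return [w + strength * (q - w) for w, q in zip(weights, dequant)]


def quantization_strength(step, transition_step, ramp_steps):
    """Hypothetical schedule: 0 before the transition, then a linear ramp to 1."""
    if step < transition_step:
        return 0.0  # still in the 16-bit phase
    if ramp_steps <= 0:
        return 1.0  # abrupt switch from 16-bit to 1.58-bit training
    return min(1.0, (step - transition_step) / ramp_steps)


if __name__ == "__main__":
    w = [0.4, -0.05, -0.9, 0.2]
    _, ternary, scale = absmean_ternary(w)
    print(ternary, scale)
    for step in (50, 100, 200, 400):
        s = quantization_strength(step, transition_step=100, ramp_steps=200)
        print(step, s, phased_weights(w, s))
```

With `ramp_steps=0` the schedule reduces to an abrupt 16-to-1.58-bit switch at `transition_step`; a positive ramp is one simple way to gradually phase in quantization strength.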