Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization
Journal: arXiv
Published Date: Feb 9, 2025
Abstract
Recent advancements in reinforcement learning (RL) have achieved great
success in fine-tuning diffusion-based generative models. However, fine-tuning
continuous flow-based generative models to align with arbitrary user-defined
reward functions remains challenging, particularly due to issues such as policy
collapse from overoptimization and the prohibitively high computational cost of
likelihoods in continuous-time flows. In this paper, we propose an easy-to-use
and theoretically sound RL fine-tuning method, which we term Online
Reward-Weighted Conditional Flow Matching with Wasserstein-2 Regularization
(ORW-CFM-W2). Our method integrates RL into the flow matching framework to
fine-tune generative models with arbitrary reward functions, without relying on
gradients of rewards or filtered datasets. By introducing an online
reward-weighting mechanism, our approach guides the model to prioritize
high-reward regions in the data manifold. To prevent policy collapse and
maintain diversity, we incorporate Wasserstein-2 (W2) distance regularization
into our method and derive a tractable upper bound for it in flow matching,
effectively balancing exploration and exploitation in policy optimization. We
provide theoretical analyses to demonstrate the convergence properties and
induced data distributions of our method, establishing connections with
traditional RL algorithms featuring Kullback-Leibler (KL) regularization and
offering a more comprehensive understanding of the underlying mechanisms and
learning behavior of our approach. Extensive experiments on tasks including
target image generation, image compression, and text-image alignment
demonstrate the effectiveness of our method, which achieves optimal policy
convergence while allowing controllable trade-offs between reward maximization
and diversity preservation.
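
The abstract names the two ingredients of the method, online reward weighting and a tractable Wasserstein-2 bound, without stating the objective. The following is a minimal sketch of what such a loss could look like, not the authors' implementation. It makes several assumptions that do not come from the abstract: a linear interpolation path between noise and data, exponential weights proportional to exp(tau * reward), and a W2 surrogate given by the squared gap between the fine-tuned and reference vector fields. The names `reward_weights`, `orw_cfm_w2_loss`, `tau`, and `alpha` are hypothetical.

```python
import numpy as np


def reward_weights(rewards, tau=1.0):
    """Exponential reward weights, rescaled so that their batch mean is 1.

    Assumption: w_i is proportional to exp(tau * r_i). The abstract only says
    that samples are reward-weighted, not which weighting function is used.
    """
    z = tau * (rewards - rewards.max())  # shift by the max for numerical stability
    w = np.exp(z)
    return w / w.mean()


def orw_cfm_w2_loss(v_theta, v_ref, x1, rewards, rng, tau=1.0, alpha=0.1):
    """Reward-weighted flow matching loss plus a vector-field penalty.

    v_theta, v_ref: callables (t, x) -> array with the same shape as x
    x1:             samples of shape (n, d), assumed to be drawn online from
                    the model that is being fine-tuned
    rewards:        shape (n,), a black-box reward for each sample
    alpha:          strength of the penalty that stands in for the W2 bound
    """
    n, d = x1.shape
    x0 = rng.standard_normal((n, d))   # noise endpoints
    t = rng.uniform(size=(n, 1))       # one time per sample, broadcast over d
    xt = (1.0 - t) * x0 + t * x1       # assumed linear probability path
    target = x1 - x0                   # conditional velocity of that path

    pred = v_theta(t, xt)
    cfm = np.sum((pred - target) ** 2, axis=1)           # flow matching term
    reg = np.sum((pred - v_ref(t, xt)) ** 2, axis=1)     # distance to reference
    w = reward_weights(rewards, tau)
    return float(np.mean(w * cfm + alpha * reg))


# Toy usage with hand-written vector fields (purely illustrative).
rng = np.random.default_rng(0)
samples = rng.standard_normal((128, 2))
toy_rewards = -np.sum((samples - 1.0) ** 2, axis=1)  # prefers points near (1, 1)
loss = orw_cfm_w2_loss(
    v_theta=lambda t, x: 1.0 - x,   # toy fine-tuned field
    v_ref=lambda t, x: -x,          # toy reference field
    x1=samples,
    rewards=toy_rewards,
    rng=rng,
    tau=0.5,
    alpha=0.1,
)
print(loss)
```

Under these assumptions, `tau` controls how sharply high-reward samples dominate the update and `alpha` controls how far the fine-tuned field may drift from the reference. This is the kind of reward-versus-diversity trade-off the abstract describes. The reward enters only through the weights, so no reward gradient is needed.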