A universal network strategy for lightspeed computation of entropy-regularized optimal transport.
Journal:
Neural networks : the official journal of the International Neural Network Society
PMID:
39705771
Abstract
Optimal transport (OT) is an effective tool for measuring discrepancies between probability distributions and feature histograms. To reduce its high computational complexity, entropy-regularized OT was introduced; it is computed with the Sinkhorn algorithm and can be readily integrated into neural networks. However, every time the network parameters are updated, both the value and the derivative of OT must be recomputed. When high solution accuracy is required, the computation graph of the Sinkhorn algorithm contains many layers, consuming substantial time and memory. To address this problem, we propose a novel network strategy that estimates the transport matrix in place of the Sinkhorn algorithm, significantly reducing the size of the computation graph. Compared with other neural OT methods, our approach supports arbitrary cost functions and varying marginal distributions. To avoid the numerical instability induced by a small regularization coefficient, we formulate the new method in the log domain using the dual form. We derive a theoretical error bound for the resulting algorithm under approximate inputs, and extend our approach in practice to robust OT and barycenter computation. Extensive experiments show that our method significantly outperforms baselines in both computational cost and accuracy.
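For context, a minimal sketch of the log-domain (dual-form) Sinkhorn iteration that the abstract refers to as the baseline is shown below; it is not the paper's network-based estimator, and the function and variable names (e.g. log_sinkhorn, eps, n_iters) are illustrative assumptions rather than the authors' notation.

```python
import numpy as np
from scipy.special import logsumexp

def log_sinkhorn(C, a, b, eps=0.01, n_iters=500):
    """Log-domain Sinkhorn iterations for entropy-regularized OT.

    C : (n, m) cost matrix; a, b : marginal histograms summing to 1.
    Returns the transport matrix P and the regularized transport cost <P, C>.
    """
    f = np.zeros_like(a)  # dual potential for the row marginal a
    g = np.zeros_like(b)  # dual potential for the column marginal b
    for _ in range(n_iters):
        # Stabilized updates via log-sum-exp: no overflow/underflow even for small eps
        f = eps * (np.log(a) - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - C) / eps, axis=0))
    # Recover the primal transport plan from the dual potentials
    P = np.exp((f[:, None] + g[None, :] - C) / eps)
    return P, np.sum(P * C)

# Hypothetical usage: squared-Euclidean cost between two small point clouds
rng = np.random.default_rng(0)
x, y = rng.normal(size=(50, 2)), rng.normal(size=(60, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a, b = np.full(50, 1 / 50), np.full(60, 1 / 60)
P, ot_val = log_sinkhorn(C, a, b, eps=0.05)
```

When such iterations are unrolled inside a neural network's computation graph, every update of the network parameters requires backpropagating through all n_iters layers, which is the time and memory cost the paper's transport-matrix estimator aims to avoid.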