Reparametrization trick
In machine learning, the reparametrization trick is a method to construct unbiased estimators of gradients, typically used for performing gradient descent on expectations.
It was introduced independently[1] to variational inference by Kingma & Welling (2014)[2], Rezende et al. (2014)[3], and Titsias & Lazaro-Gredilla (2014)[4].
Mathematical principle
A problem commonly encountered in machine learning is estimating the gradient of an expectation. The problem is as follows:
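Given a family of probability distributions $(p_\theta)_{\theta \in \Theta}$,[note 1] each parameter $\theta$ defines a probability measure $p_\theta$ on a base space $\Omega$.
Given any function $f: \Omega \to \mathbb{R}$ that is nice enough to allow differentiating under the integral sign, the gradient-of-expectation is:
$$\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)] = \nabla_\theta \int_\Omega f(x)\, p_\theta(dx) \tag{main}$$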
The key problem is to estimate the gradient-of-expectation efficiently. There are several possible solutions.[note 2]
Analytic
If the expectation has a closed-form solution, then its gradient also has a closed-form solution. This was essentially the only method used by statisticians before large-scale computing was available.
There are several problems with this. First, a closed-form solution rarely exists. Second, even when it exists, it might be very expensive to compute, especially when $\Omega$ has a large dimension (the curse of dimensionality). The typical solution for such intractable problems is the Monte Carlo method.
Monte Carlo method
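The naive method for estimating Equation (main) is to sample many $x \sim p_\theta$ to estimate $\mathbb{E}_{x \sim p_\theta}[f(x)]$, then vary $\theta$ slightly to $\theta + \delta\theta$, estimate $\mathbb{E}_{x \sim p_{\theta+\delta\theta}}[f(x)]$, and so on, and finally fit a linear operator $L$ such that
$$\mathbb{E}_{x \sim p_{\theta+\delta\theta}}[f(x)] \approx \mathbb{E}_{x \sim p_\theta}[f(x)] + L\,\delta\theta.$$
This is highly inefficient. To save time, one should instead compute the gradient exactly where it is tractable, rather than estimating the gradient everywhere. In most commonly used examples, while the integral $\int_\Omega f\, dp_\theta$ is intractable, the gradients $\nabla_\theta p_\theta$ are tractable (the family $p_\theta$ is often designed by practitioners to make them so).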
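As such, we consider the following expansion of Equation (main):
$$\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)] = \int_\Omega f(x)\, \nabla_\theta p_\theta(dx) \tag{Monte Carlo}$$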
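Instead of first estimating the integrals and then estimating the gradient, we now compute the gradients exactly and then estimate the integrals. The integrals are still intractable in general, but they can be estimated by Monte Carlo integration. To perform Monte Carlo integration, an integral $\int g(x)\, q(dx)$ must have a probability distribution $q$ to sample $x$ from. If we use $p_\theta$ for Monte Carlo integration, writing $p_\theta(x)$ for its density, we obtain
$$\int_\Omega f(x)\, \nabla_\theta p_\theta(dx) = \int_\Omega f(x)\, \frac{\nabla_\theta p_\theta(x)}{p_\theta(x)}\, p_\theta(x)\, dx = \mathbb{E}_{x \sim p_\theta}\left[f(x)\, \nabla_\theta \ln p_\theta(x)\right]$$
and thus
$$\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)] \approx \frac{1}{N} \sum_{i=1}^{N} f(x_i)\, \nabla_\theta \ln p_\theta(x_i), \qquad x_1, \dots, x_N \sim p_\theta. \tag{REINFORCE}$$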
This is the equation used in policy gradient methods, to be detailed below.
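As a minimal sketch of the (REINFORCE) estimator, the following Python/NumPy code estimates the gradient of $\mathbb{E}_{x \sim \mathcal{N}(\mu, \sigma^2)}[f(x)]$ with respect to $(\mu, \sigma)$ using only samples of $x$ and the closed-form score $\nabla_\theta \ln p_\theta(x)$ of the normal density. The function name `score_function_grad`, the choice of a 1D normal family, and the test function $f(x) = x^2$ are illustrative assumptions; for that $f$, the exact gradient of $\mathbb{E}[x^2] = \mu^2 + \sigma^2$ is $(2\mu, 2\sigma)$, which gives a check on the estimate.

```python
import numpy as np


def score_function_grad(f, mu, sigma, n_samples=200_000, seed=0):
    """Score-function (REINFORCE) estimate of the gradient of E_{x~N(mu, sigma^2)}[f(x)]
    with respect to (mu, sigma).

    For the 1D normal density p(x) = N(x; mu, sigma^2):
      d/dmu    ln p(x) = (x - mu) / sigma^2
      d/dsigma ln p(x) = ((x - mu)^2 - sigma^2) / sigma^3
    Only f(x) is evaluated; f does not need to be differentiable.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(mu, sigma, size=n_samples)  # samples from p_theta itself
    fx = f(x)
    score_mu = (x - mu) / sigma**2
    score_sigma = ((x - mu) ** 2 - sigma**2) / sigma**3
    # Monte Carlo average of f(x) * grad ln p(x)
    return np.mean(fx * score_mu), np.mean(fx * score_sigma)


# Test function f(x) = x^2: E[x^2] = mu^2 + sigma^2, so the exact gradient is (2*mu, 2*sigma).
g_mu, g_sigma = score_function_grad(lambda x: x**2, mu=1.0, sigma=2.0)
print(g_mu, g_sigma)  # compare with the exact values (2.0, 4.0)
```

The estimator never differentiates $f$, which is why it also applies to non-differentiable rewards; the price is typically a high variance.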
The reparametrization trick
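The main issue in estimating Equation (main) is an "entanglement" between the distribution $p_\theta$ and the expectation of $f$ to be estimated. The entanglement comes to the fore when we vary $\theta$: changing $\theta$ changes the distribution from which $x$ is drawn. The reparametrization trick pushes all dependence on $\theta$ into a deterministic function, and then performs the obvious exchange of the gradient with an expectation over a distribution that no longer depends on $\theta$.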
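The prototypical example is the family of 1D normal distributions $\{\mathcal{N}(\mu, \sigma^2) : \mu \in \mathbb{R}, \sigma > 0\}$. Given any $\theta = (\mu, \sigma)$, we can sample $x \sim \mathcal{N}(\mu, \sigma^2)$ as $x = \mu + \sigma z$, with $z \sim \mathcal{N}(0, 1)$ and with $\mu, \sigma$ entering only through the deterministic map $z \mapsto \mu + \sigma z$.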
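A minimal NumPy sketch of this sampling scheme, using a hypothetical helper `sample_normal_reparam`: all randomness comes from the seed noise $z$, and $(\mu, \sigma)$ enter only through the deterministic transform.

```python
import numpy as np


def sample_normal_reparam(mu, sigma, n_samples, rng):
    """Draw x ~ N(mu, sigma^2) as x = mu + sigma * z with seed noise z ~ N(0, 1)."""
    z = rng.standard_normal(n_samples)  # seed noise: its distribution does not depend on (mu, sigma)
    return mu + sigma * z  # deterministic transform g_theta(z)


rng = np.random.default_rng(0)
x = sample_normal_reparam(mu=3.0, sigma=0.5, n_samples=100_000, rng=rng)
print(x.mean(), x.std())  # empirical mean and standard deviation, near 3.0 and 0.5
```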
Remark: The idea is similar to probability transforms such as the Box–Muller transform, where we have only one "seed" random number generator and must construct other probability distributions by applying deterministic transforms to the random numbers it generates.
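In general, given a family of distributions $(p_\theta)_{\theta \in \Theta}$ on $\Omega$, we can perform the reparametrization trick if there exists a seed distribution $q$ on a space $\mathcal{Z}$, and a transform function $g: \Theta \times \mathcal{Z} \to \Omega$, written $g_\theta(z)$, such that for any $\theta \in \Theta$, we can sample from $p_\theta$ by sampling $z \sim q$, then computing $x = g_\theta(z)$. That is, $p_\theta$ is the pushforward $(g_\theta)_* q$. With the seed distribution and the transform, and assuming $\theta \mapsto f(g_\theta(z))$ is differentiable, we obtain the reparametrization trick equation:
$$\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)] = \nabla_\theta \mathbb{E}_{z \sim q}[f(g_\theta(z))] = \mathbb{E}_{z \sim q}[\nabla_\theta f(g_\theta(z))] \tag{reparametrization trick}$$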
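For the 1D normal family, the (reparametrization trick) estimator can be sketched as follows. This is an illustration rather than a reference implementation: `reparam_grad` is a hypothetical name, the derivative $f'$ is passed in by hand where a real system would use automatic differentiation, and the test function $f(x) = x^2$ with exact gradient $(2\mu, 2\sigma)$ is an assumption chosen so the result can be checked.

```python
import numpy as np


def reparam_grad(df, mu, sigma, n_samples=200_000, seed=0):
    """Pathwise (reparametrization) estimate of the gradient of E_{x~N(mu, sigma^2)}[f(x)]
    with respect to (mu, sigma).

    With x = g_theta(z) = mu + sigma * z and z ~ N(0, 1), the chain rule gives
      d/dmu    f(g_theta(z)) = f'(x)
      d/dsigma f(g_theta(z)) = f'(x) * z
    `df` is the derivative f', supplied by hand here in place of automatic differentiation.
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(n_samples)  # samples from the fixed seed distribution q
    x = mu + sigma * z
    dfx = df(x)
    return np.mean(dfx), np.mean(dfx * z)


# Test function f(x) = x^2, so f'(x) = 2x and the exact gradient is (2*mu, 2*sigma).
g_mu, g_sigma = reparam_grad(lambda x: 2 * x, mu=1.0, sigma=2.0)
print(g_mu, g_sigma)  # compare with the exact values (2.0, 4.0)
```

With $\mu = 1$ and $\sigma = 2$, the per-sample terms are $2x = 2 + 4z$ and $2xz = 2z + 4z^2$, with standard deviations of exactly 4 and 6, whereas the score-function (REINFORCE) estimator applied to the same problem has per-sample standard deviations of roughly 8.6 and 18.9. Lower variance of this kind is the usual practical motivation for the trick, although it is not guaranteed for every $f$.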
Motivation
Many problems in statistics and machine learning are of the form: find the "best" parameters. Usually, "best" is defined as "achieving minimal loss" or "maximum reward", and taking the gradient, if possible, is useful for optimization.
Parameter estimation in statistics
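For example, consider the parameter estimation problem: we are given data $x$ sampled from a distribution $p_{\theta^*}$, or at least from a distribution close enough to some $p_{\theta^*}$, and the problem is to estimate $\theta^*$. Treating the data as a single point $x$ is no loss of generality, because if we have a sequence of independently sampled data $x_1, \dots, x_n$, then we can simply consider them as one big datum $x = (x_1, \dots, x_n)$ with the product distribution $p_\theta^{\otimes n}$. As we will see below, the problem often reduces to solving one of the following problems:
$$\min_\theta \mathbb{E}_{x \sim p_\theta}[f(x)] \qquad \text{or} \qquad \min_\phi \mathbb{E}_{y \sim q}[f_\phi(y)],$$
that is, minimizing an expectation whose distribution depends on the parameter, or one whose integrand depends on the parameter.
First approach (Blackwell 1951)[5][6]: find a distribution that can best mock up the observed samples:
$$\min_\theta \mathbb{E}_{x' \sim p_\theta}[d(x, x')]$$
where $d$ measures the difference between two points in $\Omega$. It can be a metric, or be more general.
Set $f(x') = d(x, x')$; then this problem reduces to $\min_\theta \mathbb{E}_{x' \sim p_\theta}[f(x')]$.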
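Second approach (frequentist estimation): define a loss function $L(p, p')$, which measures the difference between two "points" (each point is a probability distribution, very big points indeed) in the space of distributions under consideration.
Then, minimize the estimation risk (expected loss):
$$\min_\phi \mathbb{E}_{x \sim p_{\theta^*}}\left[L\big(p_{\theta^*}, p_{\hat\theta_\phi(x)}\big)\right]$$
where $\hat\theta_\phi$ is an estimator, parametrized by $\phi$.
The above definition is not very interesting, since we could define the "blind guess" estimator $\hat\theta(x) = \theta_0$, which ignores the data and always returns some fixed $\theta_0$. It would happen to be exactly right if $\theta^* = \theta_0$. Thus, we must additionally require the estimator to perform well on many different possible $\theta^*$.
There are many possible ways to formalize the idea of "perform well on many different $\theta^*$". A common version is the minimax criterion:
$$\min_\phi \max_{\theta \in \Theta} \mathbb{E}_{x \sim p_\theta}\left[L\big(p_\theta, p_{\hat\theta_\phi(x)}\big)\right]$$
Set $f_\phi(\theta, x) = L\big(p_\theta, p_{\hat\theta_\phi(x)}\big)$; then this problem reduces to
$$\min_\phi \max_{\theta \in \Theta} \mathbb{E}_{x \sim p_\theta}[f_\phi(\theta, x)],$$
which is almost in the form $\min_\theta \mathbb{E}_{x \sim p_\theta}[f(x)]$, but not quite.
For example, when $\Theta$ is a subset of $\mathbb{R}$, the loss is the squared error $L(p_\theta, p_{\theta'}) = (\theta - \theta')^2$, and the set of estimators contains only unbiased estimators, then if the minimum-variance unbiased estimator exists, it is the solution to the above problem.[note 3]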
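Third approach (Bayesian estimation): Imposing a maximum gives frequentist estimation a kind of "game-theoretic" flavor, since it is formally equivalent to a zero-sum game between a statistician and nature. The statistician proposes an estimator $\hat\theta_\phi$, and nature replies with the worst-case $\theta$.
However, nature is indifferent: the statistician's choice of $\phi$ has no effect on $\theta^*$. Consequently, Bayesian estimation models this by imposing a prior distribution $\pi$ over $\Theta$, and optimizing the following:
$$\min_\phi \mathbb{E}_{\theta \sim \pi}\, \mathbb{E}_{x \sim p_\theta}\left[L\big(p_\theta, p_{\hat\theta_\phi(x)}\big)\right]$$
Let $q$ be the joint distribution of $(\theta, x)$ and set $f_\phi(\theta, x) = L\big(p_\theta, p_{\hat\theta_\phi(x)}\big)$; then this problem reduces to
$$\min_\phi \mathbb{E}_{(\theta, x) \sim q}[f_\phi(\theta, x)],$$
which is in the form $\min_\phi \mathbb{E}_{y \sim q}[f_\phi(y)]$.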
Reinforcement learning
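In reinforcement learning, a policy with parameters $\theta$ induces a distribution $p_\theta$ over trajectories $\tau$, and the goal is to maximize the expected return $\mathbb{E}_{\tau \sim p_\theta}[R(\tau)]$. There is an "entanglement" between the distribution $p_\theta$ and the function $R$: changing $\theta$ changes which trajectories are likely.
To perform gradient ascent on the expected return, one must estimate the gradient $\nabla_\theta \mathbb{E}_{\tau \sim p_\theta}[R(\tau)]$, which is Equation (main) with $f = R$.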
Main methods
There are many different ways to perform the reparametrization trick, for diverse purposes.
Reparametrizing a distribution family
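Given a family of distributions $(p_\theta)_{\theta \in \Theta}$, we can apply the reparametrization trick if we have a way to decompose the family into a fixed "seed" random generator and a family of parametrized deterministic functions.
For example, the family of normal distributions on $\mathbb{R}^n$ is $\{\mathcal{N}(\mu, \Sigma)\}$. Given any $\theta = (\mu, \Sigma)$, we can sample $x \sim \mathcal{N}(\mu, \Sigma)$ as $x = \mu + \Sigma^{1/2} z$, with $z \sim \mathcal{N}(0, I_n)$ and $g_\theta(z) = \mu + \Sigma^{1/2} z$, where $\Sigma^{1/2}$ is any matrix satisfying $\Sigma^{1/2} (\Sigma^{1/2})^{\mathsf{T}} = \Sigma$.
Since $\Sigma$ is positive semi-definite, such a $\Sigma^{1/2}$ is guaranteed to exist by the spectral theorem, but it is not unique. It can be found by Cholesky decomposition or by singular value decomposition. Different choices have different theoretical and practical advantages.[7]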
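The non-uniqueness of $\Sigma^{1/2}$ can be checked numerically. The sketch below (NumPy; the helper names and the example mean and covariance are arbitrary assumptions) builds two different square roots of the same covariance, a lower-triangular Cholesky factor and the symmetric square root from an eigendecomposition (for a symmetric positive semi-definite matrix this coincides with the singular value decomposition), and confirms that samples $x = \mu + \Sigma^{1/2} z$ have the same empirical covariance either way.

```python
import numpy as np


def sqrt_cholesky(cov):
    """Lower-triangular factor L with L @ L.T == cov (NumPy requires cov to be positive definite)."""
    return np.linalg.cholesky(cov)


def sqrt_symmetric(cov):
    """Symmetric square root V diag(sqrt(w)) V^T from the eigendecomposition cov = V diag(w) V^T."""
    w, v = np.linalg.eigh(cov)
    w = np.clip(w, 0.0, None)  # guard against tiny negative eigenvalues from rounding
    return v @ np.diag(np.sqrt(w)) @ v.T


def sample_mvn_reparam(mean, cov_sqrt, n_samples, rng):
    """Draw rows x = mean + cov_sqrt @ z with seed noise z ~ N(0, I)."""
    z = rng.standard_normal((n_samples, mean.shape[0]))
    return mean + z @ cov_sqrt.T  # row-wise form of cov_sqrt @ z


mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.6], [0.6, 1.0]])
rng = np.random.default_rng(0)

empirical_covs = []
for cov_sqrt in (sqrt_cholesky(cov), sqrt_symmetric(cov)):
    assert np.allclose(cov_sqrt @ cov_sqrt.T, cov)  # both are valid square roots
    x = sample_mvn_reparam(mean, cov_sqrt, 200_000, rng)
    empirical_covs.append(np.cov(x, rowvar=False))
    print(np.round(empirical_covs[-1], 2))
```

The Cholesky factor is cheap to compute and triangular, while the symmetric root also handles singular (only semi-definite) covariances once the eigenvalues are clipped at zero.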
Gumbel max tricks
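The prototype of Gumbel tricks is the Gumbel-max trick, which allows one to create any categorical distribution using just a Gumbel distribution random number generator. Concretely, if $G_1, \dots, G_k$ are i.i.d. standard Gumbel random variables and $\pi_1, \dots, \pi_k$ are class probabilities, then $\arg\max_i (\ln \pi_i + G_i)$ is distributed as $\operatorname{Categorical}(\pi_1, \dots, \pi_k)$.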
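A minimal NumPy sketch of the Gumbel-max trick, using NumPy's `Generator.gumbel` as the Gumbel random number generator; `gumbel_max_sample` and the example probabilities are illustrative assumptions. The empirical category frequencies should approach the target probabilities.

```python
import numpy as np


def gumbel_max_sample(log_probs, n_samples, rng):
    """Sample category indices as argmax_i (log_probs[i] + G_i), G_i i.i.d. standard Gumbel."""
    g = rng.gumbel(loc=0.0, scale=1.0, size=(n_samples, log_probs.shape[0]))
    return np.argmax(log_probs + g, axis=1)


probs = np.array([0.1, 0.2, 0.7])
rng = np.random.default_rng(0)
samples = gumbel_max_sample(np.log(probs), 100_000, rng)
freqs = np.bincount(samples, minlength=probs.size) / samples.size
print(np.round(freqs, 3))  # empirical frequencies, near [0.1, 0.2, 0.7]
```

Because $\arg\max$ is unchanged when the same constant is added to every entry, unnormalized logits can be passed in place of $\ln \pi_i$.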
The Gumbel-max trick generalizes to other distributions:[8]
Theorem — Gumbel
This provides a proof that the Gumbel, Weibull, and Fréchet distributions are max-stable, which is one-half of the extreme value theorem.
Gumbel softmax method
The Gumbel-max trick allows sampling from the categorical distribution, but it cannot be used for training with gradient descent, because the gradient of $\arg\max_i (\ln \pi_i + G_i)$ with respect to $\pi$ is zero almost everywhere. This is because $\arg\max$ is "hard", that is, insensitive to small variations of $\pi$. Intuitively speaking, this can be interpreted as saying that the model simply predicts "category $i$ is most likely" without saying by how much it is the most likely category.
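To create better gradients, the model should predict a distribution over the categories, and the standard method is the softmax function, creating the Gumbel softmax method:[1][9]
$$y_i = \frac{\exp\big((\ln \pi_i + G_i)/\tau\big)}{\sum_{j=1}^{k} \exp\big((\ln \pi_j + G_j)/\tau\big)}, \qquad i = 1, \dots, k,$$
where $\tau > 0$ is a temperature; as $\tau \to 0$, $y$ approaches the one-hot vector of the Gumbel-max sample. After imposing a good loss function, such as the cross-entropy loss, one can train the model parameters $\theta$ (which determine $\pi$) by standard gradient descent on $\mathbb{E}_G[-\ln y_c]$, where $c$ is the correct label; the gradient $\nabla_\theta \mathbb{E}_G[-\ln y_c] = \mathbb{E}_G[-\nabla_\theta \ln y_c]$ is estimated by sampling $G$.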
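The following sketch (NumPy, hypothetical helper names) computes one Gumbel-softmax sample for fixed noise $G$ and the gradient of the cross-entropy loss $-\ln y_c$ with respect to the logits $\ln \pi_i$. For $y = \operatorname{softmax}(s/\tau)$ with $s_i = \ln \pi_i + G_i$, that gradient is $(y - e_c)/\tau$, where $e_c$ is the one-hot vector of the label; unlike the Gumbel-max sample, it is generally nonzero. The example probabilities and temperatures are assumptions.

```python
import numpy as np


def gumbel_softmax(logits, gumbel_noise, tau):
    """Relaxed one-hot sample y = softmax((logits + G) / tau) for given Gumbel noise G."""
    s = (logits + gumbel_noise) / tau
    s = s - s.max()  # numerical stability; softmax is shift-invariant
    e = np.exp(s)
    return e / e.sum()


def cross_entropy_grad(logits, gumbel_noise, tau, label):
    """Gradient of -ln y[label] with respect to logits, for the fixed noise sample."""
    y = gumbel_softmax(logits, gumbel_noise, tau)
    onehot = np.zeros_like(y)
    onehot[label] = 1.0
    return (y - onehot) / tau


rng = np.random.default_rng(0)
logits = np.log(np.array([0.1, 0.2, 0.7]))
g = rng.gumbel(size=logits.shape)  # one draw of the noise G
for tau in (1.0, 0.5, 0.1):
    # Lower temperatures push the same relaxed sample closer to one-hot.
    print(tau, np.round(gumbel_softmax(logits, g, tau), 3))
print(cross_entropy_grad(logits, g, tau=0.5, label=0))
```

If the model outputs unnormalized logits $\theta$ with $\pi = \operatorname{softmax}(\theta)$, then $\ln \pi_i$ and $\theta_i$ differ by a constant shared by all $i$, which the softmax cancels, so the same gradient applies to $\theta$. Averaging this gradient over fresh Gumbel draws gives a Monte Carlo estimate of $\nabla_\theta \mathbb{E}_G[-\ln y_c]$.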
Estimating bounds of partition functions
Inspired by the connection between information theory and statistical mechanics,[10] energy-based models, such as the Boltzmann machine and the deep belief network, are statistical models defined by an energy function (or "potential function") and a temperature.
Consider a set of discrete random variables $X_1, \dots, X_n$, each being the state of a particle $i$. Let the system of particles interact, and let the energy of the entire system be $E(x_1, \dots, x_n)$ when $X_1 = x_1, \dots, X_n = x_n$. Then, when the system is in contact with a heat bath with temperature $T$, after reaching equilibrium, the distribution of the states of the system is the Boltzmann distribution:
$$P(X_1 = x_1, \dots, X_n = x_n) = \frac{e^{-\beta E(x_1, \dots, x_n)}}{Z(\beta)}$$
where $\beta = 1/(k_B T)$ is the inverse temperature of the heat bath.
The normalizing constant $Z(\beta) = \sum_{x_1, \dots, x_n} e^{-\beta E(x_1, \dots, x_n)}$ is the partition function of the system. It depends on the temperature and the energy function.
In general, partition functions are intractable, making it important to estimate them in practice. There is a family of reparametrization tricks for accomplishing the estimation.
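As an illustration of the link between Gumbel perturbations and partition functions, the sketch below uses the fact that if each configuration $x$ receives an independent standard Gumbel perturbation $G(x)$, then $\max_x(-\beta E(x) + G(x))$ is Gumbel-distributed with location $\ln Z(\beta)$, so its mean is $\ln Z(\beta)$ plus the Euler–Mascheroni constant. This full perturbation needs one Gumbel variable per configuration and is therefore as expensive as enumeration; practical methods, such as those cited under the next headings, aim to avoid this cost, and this sketch does not implement them. The 3-spin energy function and its parameters `J`, `h`, `beta` are hypothetical.

```python
import itertools

import numpy as np

# Hypothetical toy system: 3 spins in {-1, +1} with pairwise coupling J and external field h.
J, h, beta = 0.8, 0.3, 1.0


def energy(x):
    x = np.asarray(x, dtype=float)
    pair = sum(x[i] * x[j] for i in range(len(x)) for j in range(i + 1, len(x)))
    return -J * pair - h * x.sum()


states = list(itertools.product([-1, 1], repeat=3))
neg_beta_e = np.array([-beta * energy(s) for s in states])

# Exact log partition function by enumerating all 2^3 states (log-sum-exp for stability).
m = neg_beta_e.max()
log_z_exact = m + np.log(np.exp(neg_beta_e - m).sum())

# Gumbel perturbation: the max of the perturbed values has mean ln Z + Euler-Mascheroni constant.
rng = np.random.default_rng(0)
g = rng.gumbel(size=(100_000, len(states)))  # one Gumbel draw per configuration per sample
log_z_gumbel = np.mean(np.max(neg_beta_e + g, axis=1)) - np.euler_gamma

print(round(log_z_exact, 3), round(log_z_gumbel, 3))
```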
Upper bounds[11]
Lower bounds[12]
Applications
Variational autoencoder
Two similar methods
Inference compilation
Amortized inference
Related methods
The reparametrization trick is a general technique
REINFORCE
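The policy gradient method in reinforcement learning, proposed by Williams (1992),[15] uses the following expansion of Equation (main):
$$\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)] = \mathbb{E}_{x \sim p_\theta}\left[f(x)\, \nabla_\theta \ln p_\theta(x)\right] \tag{policy gradient}$$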
It is also called the "likelihood ratio method" and the "score function method".
The policy gradient method has many variants, such as policy gradients with function approximation[16] and deterministic policy gradients.[17] An introduction can be found in Sutton and Barto's textbook.[18]
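A minimal sketch of the (policy gradient) estimator for a one-step problem (a multi-armed bandit) with a softmax policy $\pi_\theta(a) = \operatorname{softmax}(\theta)_a$, for which $\nabla_\theta \ln \pi_\theta(a) = e_a - \pi_\theta$. The bandit setting, reward values, and helper names are illustrative assumptions; the exact gradient $\pi_k (r_k - \sum_a \pi_a r_a)$ is included only to check the Monte Carlo estimate.

```python
import numpy as np


def softmax(theta):
    e = np.exp(theta - theta.max())
    return e / e.sum()


def reinforce_grad(theta, rewards, n_samples, rng):
    """Monte Carlo estimate of grad_theta E_{a~softmax(theta)}[r(a)] = E[r(a) grad_theta ln pi(a)]."""
    pi = softmax(theta)
    actions = rng.choice(len(theta), size=n_samples, p=pi)  # sample actions from the policy
    # For a softmax policy, grad_theta ln pi(a) = onehot(a) - pi.
    score = np.eye(len(theta))[actions] - pi
    return np.mean(rewards[actions][:, None] * score, axis=0)


def exact_grad(theta, rewards):
    """d/dtheta_k sum_a pi_a r_a = pi_k (r_k - sum_a pi_a r_a)."""
    pi = softmax(theta)
    return pi * (rewards - pi @ rewards)


theta = np.array([0.0, 0.5, -0.5])
rewards = np.array([1.0, 0.0, 2.0])
rng = np.random.default_rng(0)
est = reinforce_grad(theta, rewards, 200_000, rng)
print(np.round(est, 3), np.round(exact_grad(theta, rewards), 3))
```

Only sampled actions and their rewards are needed, not a model of how rewards depend on actions, which is what makes the estimator usable in reinforcement learning.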
Notes
References
- [1] Maddison, C.; Mnih, A.; Teh, Y. (2017). "The concrete distribution: A continuous relaxation of discrete random variables". Proceedings of the International Conference on Learning Representations.
- [2] Kingma, Diederik P.; Welling, Max (2014-05-01). "Auto-Encoding Variational Bayes". arXiv:1312.6114 [stat.ML].
- [3] Rezende, Danilo Jimenez; Mohamed, Shakir; Wierstra, Daan (2014-06-18). "Stochastic Backpropagation and Approximate Inference in Deep Generative Models". International Conference on Machine Learning. PMLR: 1278–1286. arXiv:1401.4082.
- [4] Titsias, Michalis; Lázaro-Gredilla, Miguel (2014-06-18). "Doubly Stochastic Variational Bayes for non-Conjugate Inference". International Conference on Machine Learning. PMLR: 1971–1979.
- [5] Blackwell, David (1951-01-01). "Comparison of Experiments". Proceedings of the Second Berkeley Symposium on Mathematical Statistics and Probability. 2: 93–103.
- [6] Keener, Robert W. (2010). Theoretical Statistics: Topics for a Core Course (Springer Texts in Statistics). p. 44. ISBN 978-1461426707.
- [7] Kessy, Agnan; Lewin, Alex; Strimmer, Korbinian (2018-10-02). "Optimal Whitening and Decorrelation". The American Statistician. 72 (4): 309–314. doi:10.1080/00031305.2016.1277159. ISSN 0003-1305.
- [8] Balog, Matej; Tripuraneni, Nilesh; Ghahramani, Zoubin; Weller, Adrian (2017-07-17). "Lost Relatives of the Gumbel Trick". International Conference on Machine Learning. PMLR: 371–379. arXiv:1706.04161.
- [9] Jang, Eric; Gu, Shixiang; Poole, Ben (April 2017). "Categorical Reparameterization with Gumbel-Softmax". ICLR 2017 - Conference Track.
- [10] Jaynes, E. T. (1957-05-15). "Information Theory and Statistical Mechanics". Physical Review. 106 (4): 620–630. Bibcode:1957PhRv..106..620J. doi:10.1103/PhysRev.106.620.
- [11] Hazan, Tamir; Jaakkola, Tommi (2012-06-27). "On the Partition Function and Random Maximum A-Posteriori Perturbations". arXiv:1206.6410 [cs.LG].
- [12] Hazan, Tamir; Maji, Subhransu; Jaakkola, Tommi (2013). "On Sampling from the Gibbs Distribution with Random Maximum A-Posteriori Perturbations". Advances in Neural Information Processing Systems. Curran Associates, Inc. 26. arXiv:1309.7598.
- [13] Le, Tuan Anh; Baydin, Atilim Gunes; Wood, Frank (2017-04-10). "Inference Compilation and Universal Probabilistic Programming". Artificial Intelligence and Statistics. PMLR: 1338–1348. arXiv:1610.09900.
- [14] Le, Tuan Anh (19 December 2017). "Amortized Inference". www.tuananhle.co.uk. Archived from the original on 2022-06-26. Retrieved 2022-06-26.
- [15] Williams, Ronald J. (1992). "Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning". Reinforcement Learning. Boston, MA: Springer US. pp. 5–32. doi:10.1007/978-1-4615-3618-5_2. ISBN 978-1-4613-6608-9. Retrieved 2022-06-26.
- [16] Sutton, Richard S; McAllester, David; Singh, Satinder; Mansour, Yishay (1999). "Policy Gradient Methods for Reinforcement Learning with Function Approximation". Advances in Neural Information Processing Systems. MIT Press. 12.
- [17] Silver, David; Lever, Guy; Heess, Nicolas; Degris, Thomas; Wierstra, Daan; Riedmiller, Martin (2014-01-27). "Deterministic Policy Gradient Algorithms". International Conference on Machine Learning. PMLR: 387–395.
- [18] Sutton, Richard S.; Barto, Andrew G. (2018). "Chapter 13". Reinforcement Learning: An Introduction (2nd ed.). Cambridge, Massachusetts. ISBN 978-0-262-03924-6. OCLC 1043175824.
This article "Reparametrization trick" is from Wikipedia. The list of its authors can be seen in its historical and/or the page Edithistory:Reparametrization trick. Articles copied from Draft Namespace on Wikipedia could be seen on the Draft Namespace of Wikipedia and not main one.
