
Neural tangent kernel (NTK) and beyond

"Meanwhile, we see the evolving development of deep learning theory on neural networks. NTK (neural tangent kernel) is proposed to characterize the gradient descent training dynamics of infinite wide (Jacot et al., 2018) or finite wide deep networks (Hanin & Nica, 2019). Wide networks are also proved to evolve as linear models under gradient descent (Lee et al., 2019). This is further leveraged to decouple the trainability and generalization of networks (Xiao et al., 2019)."

"By formulating neural networks as a Gaussian Process (no training involved), the gradient descent training dynamics can be characterized by the Neural Tangent Kernel (NTK) of infinite (Lee et al., 2019) or finite (Yang, 2019) width networks, from which several useful measures can be derived to depict the network trainability at the initialization."

1. Neural Tangent Kernel: Convergence and Generalization in Neural Networks

Author: Arthur Jacot et al.
Post: rajatvd.github.io/NTK/
        blog.ml.cmu.edu/2019/10/03/ultra-wide-deep-nets-and-the-neural-tangent-kernel-ntk/
Video: www.youtube.com/watch?v=DObobAnELkU

The key point of the paper is that a neural network of infinite width simplifies to a linear model with a kernel called the neural tangent kernel (NTK). This so-called NTK describes the dynamics of the network output in a linear fashion and provides a closed-form solution for them, under the assumption that the network is infinitely wide. The intuition is that the weights of an infinite-width network barely move from their initial state, so the network output, which depends on the weights, can be approximated by a first-order Taylor expansion around initialization (the NTK remains constant over the course of training).

Although the NTK does not fully explain the success of deep neural networks, and it fails to describe the more complex, genuinely non-linear regime in which practical networks operate, it offers an interesting new perspective on neural network training dynamics. Many follow-up works on generalization, training dynamics, and neural architecture search have stemmed from it.

---------------  The mathematical description below follows the post at rajatvd.github.io/NTK/. ----------------

Let $L(w)$ be the loss of a two-hidden-layer neural network,

$L(w) = \frac{1}{N}\sum_{i=1}^{N}\frac{1}{2}\left(f(\bar{x}_i,w)-\bar{y}_i\right)^2$,

where $f$ is the neural network, $w$ are the weights, and $(\bar{x}_i, \bar{y}_i)$ is one of our $N$ datapoints. Dropping the constant factor $\frac{1}{N}$ and omitting the input $x$, we have

(1) $L(w) = \frac{1}{2}\left \| y(w)-\bar{y} \right \|^2$.

Now, because the weights of an infinite-width neural network barely change during training, we can Taylor-expand the network function with respect to the weights around its initialization.

$f(x,w)\approx f(x,w_0) + \bigtriangledown_w f(x,w_0)^T(w-w_0)$.
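
As a minimal sketch of this linearization (my own toy example, not code from the paper; the small tanh network, its sizes, and the perturbation below are arbitrary assumptions), we can compare the network with its first-order Taylor expansion around the initial weights using a Jacobian-vector product in JAX:

    # Toy setup (assumed, not from the paper): a small scalar-output tanh MLP
    # and its first-order Taylor expansion around the initial parameters w0.
    import jax
    import jax.numpy as jnp

    def init_params(key, sizes=(2, 512, 512, 1)):
        keys = jax.random.split(key, len(sizes) - 1)
        return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
                for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

    def f(params, x):
        # The network f(x, w) for a single input x, returning a scalar.
        h = x
        for W, b in params[:-1]:
            h = jnp.tanh(h @ W + b)
        W, b = params[-1]
        return (h @ W + b)[0]

    def f_lin(params0, dparams, x):
        # f(x, w0) + grad_w f(x, w0)^T (w - w0), computed as a Jacobian-vector
        # product so the full gradient never has to be materialized.
        y0, jvp_out = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
        return y0 + jvp_out

    key = jax.random.PRNGKey(0)
    params0 = init_params(key)
    x = jnp.ones(2)
    # A small displacement w - w0; for wide networks trained by gradient descent
    # this stays small, which is why the linearization is accurate.
    dparams = jax.tree_util.tree_map(lambda p: 1e-3 * jnp.ones_like(p), params0)
    print(f(params0, x), f_lin(params0, dparams, x))

The paper's claim is that, in the infinite-width limit, the weights reached by gradient descent stay close enough to $w_0$ for this approximation to hold throughout training.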

Simplifying the Taylor expansion above with concise vector notation, we can rewrite it as,

(2) $y(w) - y(w_0) \approx \bigtriangledown_w y(w_0)^T(w-w_0)$.

Because the initial output $y(w_0)$ and the model Jacobian $\bigtriangledown_w y(w_0)$ are just constants, the above equation is linear in the weights. It is still non-linear in the input $x$, however, since the model gradient $\bigtriangledown_w f(x,w_0)$ is by no means a linear function of $x$. In the language of kernel methods, this is a linear model built on the feature map $\phi(x)$ below.

(3) $\phi(x) = \bigtriangledown_w f(x,w_0) = \bigtriangledown_w y(w_0)$.
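
Continuing the toy sketch above (it reuses the assumed f, params0, and x), the feature map is just the gradient of the output with respect to all weights, flattened into a single long vector per input:

    # Feature map phi(x) = grad_w f(x, w0), flattened across all weight tensors.
    def phi(params0, x):
        grads = jax.grad(f)(params0, x)            # same pytree structure as params0
        return jnp.concatenate([g.ravel() for g in jax.tree_util.tree_leaves(grads)])

    features = phi(params0, x)
    print(features.shape)                          # one entry per parameter of the network

With the parameters flattened the same way, the linearized model from before is exactly $f(x,w_0) + \phi(x)^T(w-w_0)$, i.e., a linear model on top of a fixed feature map.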

This feature map naturally induces a kernel on the input, which is called the neural tangent kernel. To derive it, we view equation (2) in continuous time and differentiate with respect to time, which gives,

(4) $\dot{y}(w) = \bigtriangledown y(w)^T\dot{w}$.

In order to get $\dot{w}$, we revisit the weight update under gradient descent.

$w_{k+1} = w_k - \eta \bigtriangledown_w L(w_k)$.

Rewriting this equation, we get,

$\frac{w_{k+1} - w_k}{\eta} = -\bigtriangledown_w L(w_k)$.

Again, by changing this into the form of a differential equation (this corresponds to taking the learning rate to be infinitesimally small), we get,

$\frac{dw(t)}{dt} = \dot{w}(t) = - \bigtriangledown_w L(w(t))$.

By substituting the $L(w)$ from equation (1) and taking the gradient, we get,

$\dot{w} = -\bigtriangledown y(w)(y(w) - \bar{y})$

Now, plugging $\dot{w}$ into equation (4), we get,

$\dot{y}(w) = \bigtriangledown y(w)^T\dot{w} = -\bigtriangledown y(w)^T\bigtriangledown y(w)(y(w) - \bar{y})$

where $H(w) = \bigtriangledown y(w)^T\bigtriangledown y(w)$ is called the neural tangent kernel (NTK). Going back to equation (3) and the feature map $\phi(x)$, the kernel matrix corresponding to this feature map is obtained by taking pairwise inner products between the feature maps of all data points, and this is exactly the NTK at initialization.
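
To make this concrete in the same toy setup (the batch X, the targets, and the learning rate below are arbitrary assumptions), the empirical NTK on a batch is the Gram matrix of the feature maps, and a single small gradient-descent step changes the outputs by approximately $-\eta H(y - \bar{y})$, as the equation above predicts:

    # Empirical NTK on a batch: H_ij = <phi(x_i), phi(x_j)>.
    # Reuses f, phi, and params0 from the sketches above.
    X = jax.random.normal(jax.random.PRNGKey(1), (8, 2))
    y_bar = jax.random.normal(jax.random.PRNGKey(2), (8,))

    Phi = jax.vmap(lambda xi: phi(params0, xi))(X)    # shape (N, num_params)
    H = Phi @ Phi.T                                   # the NTK at initialization

    def loss(params):
        y = jax.vmap(lambda xi: f(params, xi))(X)
        return 0.5 * jnp.sum((y - y_bar) ** 2)

    # One gradient-descent step with a small learning rate eta.
    eta = 1e-3
    grads = jax.grad(loss)(params0)
    params1 = jax.tree_util.tree_map(lambda p, g: p - eta * g, params0, grads)

    y0 = jax.vmap(lambda xi: f(params0, xi))(X)
    y1 = jax.vmap(lambda xi: f(params1, xi))(X)
    predicted = -eta * H @ (y0 - y_bar)               # first-order prediction of y1 - y0
    # Relative discrepancy; it shrinks as eta -> 0.
    print(jnp.max(jnp.abs((y1 - y0) - predicted)) / jnp.max(jnp.abs(y1 - y0)))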

Note that in an infinite-width neural network, where the weights barely change, the Jacobian of the model output can be treated as constant, which makes the NTK constant during training. This is referred to as the kernel regime, and the training dynamics reduce to a very simple linear ordinary differential equation with a closed-form solution.
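
Concretely, writing $u(t) = y(w(t)) - \bar{y}$ and freezing the kernel at its initial value $H_0 = H(w_0)$, the dynamics above become a linear ODE whose solution can be written down directly:

$\dot{u}(t) = -H_0\,u(t) \;\Rightarrow\; u(t) = e^{-H_0 t}\,u(0)$, that is, $y(t) = \bar{y} + e^{-H_0 t}\left(y(w_0) - \bar{y}\right)$.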

Because our model is over-parameterized ($p > n$), the NTK is (generically) positive definite, and by performing a spectral decomposition we can decouple the trajectory of the gradient flow into independent one-dimensional components (along the eigenvectors) that decay at a rate proportional to the corresponding eigenvalue. The key point is that they all decay (because all eigenvalues are positive), which means the gradient flow always converges to the equilibrium where the train loss is 0. This is the essence of most of the proofs in recent papers showing that gradient descent achieves zero train loss.
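
A short continuation of the toy sketch (reusing H, y0, and y_bar from above; the time points are arbitrary) makes this mode-by-mode decay visible:

    # Decompose the linearized dynamics into eigenmodes of H: each component of
    # u = y - y_bar in the eigenbasis decays as exp(-lambda_i * t).
    evals, evecs = jnp.linalg.eigh(H)              # H is symmetric positive (semi-)definite
    u0 = evecs.T @ (y0 - y_bar)                    # residual projected onto the eigenbasis

    def train_loss_at(t):
        u_t = jnp.exp(-evals * t) * u0             # independent exponential decay per mode
        return 0.5 * jnp.sum(u_t ** 2)

    print([float(train_loss_at(t)) for t in (0.0, 0.01, 0.1, 1.0)])  # monotonically decreasing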

Furthermore, the NTK is deeply connected to the generalization of deep neural networks, which is beyond the scope of this post and not easy to summarize. For further detail, I recommend reading the posts linked under 1) above rather than the original paper.

2. Disentangling Trainability and Generalization in Deep Neural Networks

Author: Xiao et al.
URL: arxiv.org/pdf/1912.13053.pdf
Conference: ICML 2020

For wide networks, the trajectory under gradient descent is governed by the Neural Tangent Kernel (NTK), and for deep networks the NTK itself maintains only weak data dependence. By analyzing the spectrum of the NTK, we formulate necessary conditions for trainability and generalization across a range of architectures in the limit of very wide and very deep networks, including Fully Connected Networks (FCNs) and Convolutional Neural Networks (CNNs). We identify several quantities related to the spectrum of the NTK that control trainability and generalization of deep networks and offer experimental evidence supporting their role in predicting the training and generalization performance of deep neural networks.
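
As a rough, hypothetical illustration (simplified stand-ins, not the exact measures defined in the paper), the kind of spectral quantities such an analysis works with can be read off an empirical NTK matrix: the largest eigenvalue bounds the usable learning rate, and the condition number indicates how slowly the hardest eigen-direction trains:

    # Hypothetical helper (not from Xiao et al.): crude spectral summaries of an
    # empirical NTK matrix H, of the sort used to reason about trainability.
    import jax.numpy as jnp

    def ntk_spectrum_summary(H):
        evals = jnp.linalg.eigvalsh(H)
        lam_max, lam_min = evals[-1], evals[0]
        return {
            "max_stable_lr": 2.0 / lam_max,           # classic step-size bound for the linearized dynamics
            "condition_number": lam_max / lam_min,    # large values mean slow-to-train directions
        }

    # e.g. with the H computed in the toy sketch of section 1:
    # print(ntk_spectrum_summary(H))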

In practice, the correspondence between the NTK and neural networks is often broken due to, e.g., insufficient width, using a large learning rate, or changing the parameterization. Our theory does not directly apply to this setting. As such, developing an understanding of training and generalization away from the NTK regime still remains an important research direction.