Exact Computation with Infinitely Wide Neural Networks
📝 Original Paper Info
- Title: On Exact Computation with an Infinitely Wide Neural Net
- ArXiv ID: 1904.11955
- Date: 2019-11-05
- Authors: Sanjeev Arora, Simon S. Du, Wei Hu, Zhiyuan Li, Ruslan Salakhutdinov, Ruosong Wang
📝 Abstract
How well does a classic deep net architecture like AlexNet or VGG19 classify on a standard dataset such as CIFAR-10 when its width --- namely, number of channels in convolutional layers, and number of nodes in fully-connected internal layers --- is allowed to increase to infinity? Such questions have come to the forefront in the quest to theoretically understand deep learning and its mysteries about optimization and generalization. They also connect deep learning to notions such as Gaussian processes and kernels. A recent paper [Jacot et al., 2018] introduced the Neural Tangent Kernel (NTK) which captures the behavior of fully-connected deep nets in the infinite width limit trained by gradient descent; this object was implicit in some other recent papers. An attraction of such ideas is that a pure kernel-based method is used to capture the power of a fully-trained deep net of infinite width. The current paper gives the first efficient exact algorithm for computing the extension of NTK to convolutional neural nets, which we call Convolutional NTK (CNTK), as well as an efficient GPU implementation of this algorithm. This results in a significant new benchmark for the performance of a pure kernel-based method on CIFAR-10, being $10\%$ higher than the methods reported in [Novak et al., 2019], and only $6\%$ lower than the performance of the corresponding finite deep net architecture (once batch normalization, etc. are turned off). Theoretically, we also give the first non-asymptotic proof showing that a fully-trained sufficiently wide net is indeed equivalent to the kernel regression predictor using NTK.
💡 Summary & Analysis
This paper explores how neural networks maintain their performance and structure during training by showing that the weight matrices do not significantly deviate from their initial values. The authors demonstrate that even with minor changes in weights, the network can still achieve rapid convergence. This finding is crucial for understanding how neural networks retain their effectiveness throughout the learning process, emphasizing the importance of proper initialization.
Key Summary
The paper demonstrates that during training, weight matrices in a neural network do not significantly deviate from their initial values and yet maintain high performance and rapid convergence.
Problem Statement
Neural networks need to converge quickly and improve performance, but large changes in weights can disrupt the original structure and performance of the model, leading to decreased generalization ability and inconsistent results.
Solution (Core Technology)
The paper proposes a method that ensures weight matrices do not deviate significantly from their initial values during training. Through analysis, it is proven that even with minor changes, the network still achieves rapid convergence.
Major Achievements
The study shows that neural networks can maintain high performance and rapid convergence without significant deviations from their initialized weights. This indicates that the network retains its structure while learning, approaching optimal solutions more closely.
Significance and Applications
This research highlights the importance of proper weight initialization in maintaining a neural network’s effectiveness during training. It provides insights into how to achieve stable and generalized learning processes, enhancing model performance and generalization ability.
📄 Full Paper Content (ArXiv Source)
In this section we derive the CNTK for a vanilla CNN. Given $`\vect{x} \in \mathbb{R}^{\nnw \times \nnh}`$ and $`(i,j) \in [\nnw]\times [\nnh]`$, we define
\phi_{ij}(\vect{x}) = [\vect{x}]_{i-(q-1)/2:i+(q-1)/2,j-(q-1)/2:j+(q-1)/2}
i.e., this operator extracts the $`(i,j)`$-th patch. By this definition, we can rewrite the CNN definition:
Let $`\vect{x}^{(0)} =\vect{x} \in \mathbb{R}^{\nnw\times \nnh \times \nnc^{(0)}}`$ be the input image where $`\nnc^{(0)}`$ is the number of channels in the input image.
For $`h=1,\ldots,L`$, $`\beta = 1,\ldots,\nnc^{(h)}`$, the intermediate outputs are defined as
\begin{align*}
\left[\tilde{\vect{x}}_{(\beta)}^{(h)}\right]_{ij} =
\sum_{\alpha=1}^{\nnc^{(h-1)}}\left\langle\mat{W}_{(\alpha),(\beta)}^{(h)},\phi_{ij}\left(\vect{x}_{(\alpha)}^{(h-1)}\right)\right\rangle, \quad
\vect{x}^{(h)}_{(\beta)} = \sqrt{\frac{c_{\sigma}}{\nnc^{(h)} \times q \times q}}\act{\tilde{\vect{x}}_{(\beta)}^{(h)}},
\end{align*}
where each $`\mat{W}_{(\alpha),(\beta)}^{(h)} \in \mathbb{R}^{q \times q}`$ is a filter with Gaussian initialization.
The final output is defined as
\begin{align*}
f(\params,\vect{x}) = \sum_{\alpha=1}^{\nnc^{(L)}} \left\langle \mat{W}_{(\alpha)}^{(L)},\vect{x}_{(\alpha)}^{(L)}\right\rangle
\end{align*}
where $`\mat{W}_{(\alpha)}^{(L)} \in \mathbb{R}^{\nnw \times \nnh}`$ is a weight matrix with Gaussian initialization.
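To make this definition concrete, below is a minimal NumPy sketch of the finite-width vanilla CNN in this parameterization (the helper names `extract_patch` and `cnn_forward` are illustrative, not from the paper's code; `c_sigma` is assumed to be the activation-dependent constant, e.g. $2$ for ReLU, and the explicit loops favor clarity over speed):

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def extract_patch(x, i, j, q):
    """phi_ij(x): the q x q patch of a (P, Q) channel centered at (i, j), zero-padded."""
    r = (q - 1) // 2
    xp = np.pad(x, ((r, r), (r, r)))
    return xp[i:i + q, j:j + q]

def cnn_forward(x, filters, last_W, q, c_sigma=2.0):
    """Forward pass of the vanilla CNN above.

    x:       (P, Q, C0) input image.
    filters: list of arrays W^(h) of shape (q, q, C_{h-1}, C_h) with iid N(0, 1) entries.
    last_W:  (P, Q, C_L) Gaussian weights of the final layer.
    """
    P, Q, _ = x.shape
    for W in filters:
        C_out = W.shape[-1]
        x_tilde = np.zeros((P, Q, C_out))
        for i in range(P):
            for j in range(Q):
                # Stack the (i, j)-th patch of every input channel: shape (q, q, C_in).
                patch = np.stack([extract_patch(x[:, :, a], i, j, q)
                                  for a in range(x.shape[-1])], axis=-1)
                # [x_tilde_beta]_{ij} = sum_alpha <W_{(alpha),(beta)}, phi_ij(x_alpha)>
                x_tilde[i, j, :] = np.tensordot(patch, W, axes=([0, 1, 2], [0, 1, 2]))
        x = np.sqrt(c_sigma / (C_out * q * q)) * relu(x_tilde)
    # f(theta, x) = sum_alpha <W^(L)_alpha, x^(L)_alpha>
    return float(np.sum(last_W * x))
```

As the number of channels grows, the kernel induced by the gradient of this network at random initialization is what the CNTK derived below computes exactly in the infinite-width limit.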
Expansion of CNTK
We expand $`\Theta^{(L)}(\vect{x},\vect{x}')`$ to show that it can be written as the sum of $`(L+1)`$ terms, each representing the inner product between the gradients with respect to the weight matrix of one layer. We first define a linear operator
\begin{align}
\linop: \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh} \rightarrow \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh} \nonumber\\
\left[\linop\left(\mat{M}\right)\right]_{k\ell,k'\ell'} = \frac{c_{\sigma}}{q^2}\tr\left(\left[\mat{M}\right]_{\indset_{k\ell,k'\ell'}}\right). \label{eqn:linop}
\end{align}
This linear operator is induced by the convolution operation. Here we use zero padding: whenever a subscript falls outside the range $`[\nnw]\times [\nnh]\times[\nnw]\times [\nnh]`$, the corresponding element is taken to be zero.
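A minimal NumPy sketch of this operator is given below, interpreting the trace over $`[\mat{M}]_{\indset_{k\ell,k'\ell'}}`$ as the sum of $`\mat{M}`$ over shared patch offsets, with zero padding as just described (the name `linop` is illustrative and `c_sigma` is assumed to be the activation constant, e.g. $2$ for ReLU):

```python
import numpy as np

def linop(M, q, c_sigma=2.0):
    """Sketch of the operator L in Equation (eqn:linop).

    M has shape (P, Q, P, Q).  For each (k, l, k', l'), sum M over the shared
    patch offsets (a, b) in [-(q-1)/2, (q-1)/2]^2, with zero padding outside
    the image, and multiply by c_sigma / q^2.
    """
    P, Q = M.shape[0], M.shape[1]
    r = (q - 1) // 2
    # Zero-pad all four spatial axes so out-of-range indices contribute zero.
    Mp = np.pad(M, ((r, r), (r, r), (r, r), (r, r)))
    out = np.zeros_like(M)
    for a in range(-r, r + 1):
        for b in range(-r, r + 1):
            # Shift both index pairs by the same offset (a, b).
            out += Mp[r + a:r + a + P, r + b:r + b + Q,
                      r + a:r + a + P, r + b:r + b + Q]
    return (c_sigma / q ** 2) * out
```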
We also define $`\id\in \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh}`$ as the identity tensor, namely $`\id_{i,j,i',j'}= \bm{1}\{i=i',j=j'\}`$, and
\Sum{\mat{M}}=\sum_{(i,j,i',j') \in [\nnw] \times [\nnh]\times [\nnw]\times [\nnh] } \mat{M}_{i,j,i',j'}.
The following property of $`\linop`$ is immediate by definition: $`\forall \mat{M},\mat{N} \in \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh}`$, we have
\begin{equation}
\label{eq:conjugate2}
\Sum{\mat{M}\odot \linop(\mat{N})} = \Sum{\linop(\mat{M})\odot \mat{N}}.
\end{equation}
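Under the same conventions as the `linop` sketch above, this self-adjointness property can be checked numerically with random tensors (a sanity check only, not part of the paper's algorithm):

```python
# Check Sum(M ⊙ L(N)) == Sum(L(M) ⊙ N) on random (P, Q, P, Q) tensors,
# where Sum(.) adds up all entries and ⊙ is the entrywise product.
rng = np.random.default_rng(0)
P, Q, q = 5, 5, 3
M = rng.standard_normal((P, Q, P, Q))
N = rng.standard_normal((P, Q, P, Q))
lhs = (M * linop(N, q)).sum()
rhs = (linop(M, q) * N).sum()
assert np.isclose(lhs, rhs)  # Property (eq:conjugate2)
```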
With this operator, we can expand the CNTK as follows (for simplicity we drop the dependence on $`\vect{x}`$ and $`\vect{x}'`$):
\begin{align*}
&\Theta^{(L)}\\
= &\tr\left(\dot{\mat{K}}^{(L)}\odot\Theta^{(L-1)}+\mat{K}^{(L)}\right) \\
= & \tr\left(\mat{K}^{(L)}\right) + \tr\left(\dot{\mat{K}}^{(L)}\odot \linop\left(\mat{K}^{(L-1)}\right)\right) + \tr\left(\dot{\mat{K}}^{(L)}\odot\linop\left(
\dot{\mat{K}}^{(L-1)} \odot \mat{\Theta}^{(L-2)}
\right)\right) \\
= & \ldots \\
= & \sum_{h=0}^{L} \tr\left(
\dot{\mat{K}}^{(L)} \odot \linop\left( \dot{\mat{K}}^{(L-1)} \odot \linop \left(\cdots\dot{\mat{K}}^{(h+1)}\odot\linop\left(\mat{K}^{(h)}\right)\cdots\right)\right)
\right).
\end{align*}
Here for $`h=L`$, the term is just $`\tr\left(\mat{K}^{(L)}\right)`$.
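Read as an algorithm, this expansion is generated by a simple forward recursion over layers. The sketch below shows that structure, assuming the tensors $`\mat{K}^{(h)}`$ and $`\dot{\mat{K}}^{(h)}`$ have already been computed for the input pair and reusing the `linop` helper from above; the base case and constant factors are illustrative rather than paper-exact:

```python
def cntk_vanilla(K, Kdot, q, c_sigma=2.0):
    """Structural sketch of the Theta^(L) expansion for the vanilla CNN.

    K, Kdot: lists of (P, Q, P, Q) tensors K^(h) and dot-K^(h) for h = 0..L,
             assumed precomputed for one pair of inputs (x, x').
    """
    L = len(K) - 1
    theta = K[0]                                   # illustrative base case
    for h in range(1, L):
        # Theta^(h) = L( dot-K^(h) ⊙ Theta^(h-1) + K^(h) )
        theta = linop(Kdot[h] * theta + K[h], q, c_sigma)
    # Last layer: trace over the spatial diagonal (i = i', j = j').
    return float(np.einsum('ijij->', Kdot[L] * theta + K[L]))
```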
In the following, we will show
\begin{align*}
%\expect_{\params \sim \mathcal{W}}
\left\langle \frac{\partial f(\params,\vect{x})}{\partial \mat{W}^{(h)}} , \frac{\partial f(\params,\vect{x}')}{\partial \mat{W}^{(h)}} \right\rangle
\approx
&\tr\left(
\dot{\mat{K}}^{(L)} \odot \linop\left( \dot{\mat{K}}^{(L-1)} \odot\linop \left(\cdots\dot{\mat{K}}^{(h)} \odot\linop\left(\mat{K}^{(h-1)}\right)\cdots\right)\right)
\right)\\
=& \Sum{ \id\odot \dot{\mat{K}}^{(L)} \odot \linop\left( \dot{\mat{K}}^{(L-1)} \odot\linop \left(\cdots\dot{\mat{K}}^{(h)} \odot\linop\left(\mat{K}^{(h-1)}\right)\cdots\right)\right) }.
\end{align*}
This can be rewritten as follows using Property [eq:conjugate2]:
\begin{align*}
%\expect_{\params \sim \mathcal{W}}
\left\langle \frac{\partial f(\params,\vect{x})}{\partial \mat{W}^{(h)}} , \frac{\partial f(\params,\vect{x}')}{\partial \mat{W}^{(h)}} \right\rangle \approx
\Sum{
\linop\left(\mat{K}^{(h-1)}\right) \odot \dot{\mat{K}}^{(h)} \odot\linop \left(\dot{\mat{K}}^{(h+1)} \cdots \odot\linop\left(\id \odot \dot{\mat{K}}^{(L)}\right)\cdots\right)
}.
\end{align*}
Derivation
We first compute the derivative of the prediction with respect to a single filter.
\begin{align*}
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}} = &\left\langle\frac{\partial f(\params,\vect{x})}{\partial \vect{x}_{(\beta)}^{(h)}}, \frac{\partial \vect{x}_{(\beta)}^{(h)}}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}}\right\rangle \\
= & \sum_{(i,j) \in [\nnw] \times [\nnh]} \left\langle\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}},\frac{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}}\right\rangle \\
= & \sum_{(i,j) \in [\nnw] \times [\nnh]} \frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}} \sqrt{\frac{c_{\sigma}}{\nnc^{(h)}q^2}}\reluder{\left[
\tilde{\vect{x}}_{(\beta)}^{(h)}
\right]_{ij}}\phi_{ij}(\vect{x}_{(\alpha)}^{(h-1)}).
\end{align*}
With this expression, we proceed to compute the inner product between the gradients with respect to the $`h`$-th layer weights:
\begin{align}
&\sum_{\alpha=1}^{\nnc^{(h-1)}}\sum_{\beta=1}^{\nnc^{(h)}}\left\langle
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}}, \frac{\partial f(\params,\vect{x}')}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}} \right\rangle \label{eqn:grad_prod}\\
= &\sum_{(i,j,i',j') \in [\nnw] \times [\nnh]\times [\nnw]\times [\nnh] }\frac{c_{\sigma}}{\nnc^{(h)}q^2}\sum_{\beta=1}^{\nnc^{(h)}}
\left(
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right) \nonumber\\
&\cdot\left(\sum_{\alpha=1}^{\nnc^{(h-1)}}\left\langle\phi_{ij}(\vect{x}_{(\alpha)}^{(h-1)}),\phi_{i'j'}(\vect{x}'^{(h-1)}_{(\alpha)})\right\rangle\right).\nonumber
\end{align}
Similar to our derivation of the NTK, we can use the following approximation:
\begin{align*}
\left(\sum_{\alpha=1}^{\nnc^{(h-1)}}\left\langle\phi_{ij}(\vect{x}_{(\alpha)}^{(h-1)}),\phi_{i'j'}(\vect{x}'^{(h-1)}_{(\alpha)})\right\rangle\right) \approx
\tr\left(\left[\mat{K}^{(h-1)}\right]_{\indset_{ij,i'j'}}\right) = \frac{q^2}{c_{\sigma}}\left[\linop\left(\mat{K}^{(h-1)}\right)\right]_{ij,i'j'}.
\end{align*}
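Here the approximation consists in replacing the empirical channel-summed covariance by its infinite-width limit $`\mat{K}^{(h-1)}`$; the step from a sum of patch inner products to a patch-wise trace is itself an exact finite-width identity, which can be checked with the sketches above (again a sanity check under the stated conventions, not the paper's code):

```python
# For the empirical covariance M[i,j,i',j'] = sum_alpha x[i,j,alpha] * x2[i',j',alpha],
# the channel-summed patch inner product equals the patch-wise trace of M,
# i.e. (q^2 / c_sigma) * linop(M) entrywise under the conventions above.
rng = np.random.default_rng(2)
P = Q = 5
q, C = 3, 4
x = rng.standard_normal((P, Q, C))
x2 = rng.standard_normal((P, Q, C))
M = np.einsum('ija,kla->ijkl', x, x2)
i, j, i2, j2 = 1, 2, 3, 0
lhs = sum(np.sum(extract_patch(x[:, :, a], i, j, q) *
                 extract_patch(x2[:, :, a], i2, j2, q)) for a in range(C))
rhs = (q ** 2 / 2.0) * linop(M, q, c_sigma=2.0)[i, j, i2, j2]
assert np.isclose(lhs, rhs)
```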
Thus it remains to show that $`\forall (i,j,i',j') \in [\nnw] \times [\nnh]\times [\nnw]\times [\nnh]`$,
\begin{align*}
&\sum_{\beta=1}^{\nnc^{(h)}}\frac{c_{\sigma}}{\nnc^{(h)}q^2}
\left(
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right)\\
\approx&
\left[\linop \left(\dot{\mat{K}}^{(h+1)} \cdots \odot\linop\left(\id \odot \dot{\mat{K}}^{(L)}\right)\cdots\right)\odot \dot{\mat{K}}^{(h)} \right]_{i,j,i',j'}
\end{align*}
The key step of this derivation is the following approximation (Equation [eq:gaussian_conditioning2]), which assumes that for each $`(i,j,i',j')`$, $`\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}`$ and $`\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}`$ are independent. This is used and made rigorous for the ReLU activation and fully-connected networks in the proof of Theorem [thm:ntk_init]; prior work gave a rigorous statement of this approximation, in an asymptotic sense, for CNNs.
\begin{equation}
\begin{split}\label{eq:gaussian_conditioning2}
&\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} \left(
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right) \\
\approx&
\left(\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}}
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}}
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right)
\end{split}
\end{equation}
Note that
\begin{align*}
\frac{c_{\sigma}}{\nnc^{(h)}q^2}\sum_{\beta=1}^{\nnc^{(h)}}\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}} \approx \left[\dot{\mat{K}}^{(h)}\left(\vect{x},\vect{x}'\right)\right]_{ij,i'j'},
\end{align*}
so the derivation is complete once we show
\begin{equation}
\label{eq:defi_G}\mat{G}^{(h)}(\vect{x},\vect{x'},\vect{\theta}):=\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}}
\frac{\partial f(\params,\vect{x})}{\partial \vect{x}_{(\beta)}^{(h)}}\otimes \frac{\partial f(\params,\vect{x}')}{\partial \vect{x}'^{(h)}_{(\beta)}}
\approx
\linop \left(\dot{\mat{K}}^{(h+1)} \cdots \odot\linop\left(\id \odot \dot{\mat{K}}^{(L)}\right)\cdots\right).
\end{equation}
Now, we tackle the term $`\left( \frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}} \right)`$. Notice that
\begin{align*}
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}} = & \sum_{\gamma=1}^{\nnc^{(h+1)}}\sum_{(k,\ell) \in [\nnw]\times [\nnh] }
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}
\frac{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}},
\end{align*}
and for $`\gamma \in [\nnc^{(h+1)}]`$ and $`(k,\ell) \in [\nnw] \times [\nnh]`$
\begin{align*}
\frac{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}} = \begin{cases}
\sqrt{\frac{c_{\sigma}}{\nnc^{(h+1)}q^2}}\reluder{\left[\tilde{\vect{x}}_{(\gamma)}^{(h+1)}\right]_{k\ell}} \left[\mat{W}_{(\beta),(\gamma)}^{(h+1)}\right]_{i-k+q-1,j-\ell+q-1}&\text{ if } (i,j) \in \indset_{k\ell} \\
0 &\text{ otherwise }
\end{cases}.
\end{align*}
We then have
\begin{align}
&\left[\mat{G}^{(h)}(\vect{x},\vect{x'},\vect{\theta})\right]_{ij,i'j'} = \frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} \frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}} \frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x'}^{(h)}_{(\beta)}\right]_{i'j'}} \nonumber \\
=&\sum_{k,\ell,k',\ell'} \frac{c_{\sigma}}{\nnc^{(h+1)}q^2}\sum_{\gamma=1}^{\nnc^{(h+1)}}\left(
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}
\frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left(\reluder{\left[\tilde{\vect{x}}^{(h+1)}_{(\gamma)}\right]_{k\ell}}\reluder{\left[\tilde{\vect{x}}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right)\nonumber\\
& \cdot \frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}\right\} \left[\mat{W}_{(\beta),(\gamma)}^{(h+1)}\right]_{i-k+q-1,j-\ell+q-1}\left[\mat{W}_{(\beta),(\gamma)}^{(h+1)}\right]_{i'-k'+q-1,j'-\ell'+q-1}\nonumber\\
\approx & \sum_{k,\ell,k',\ell'} \frac{c_{\sigma}}{\nnc^{(h+1)}q^2}\sum_{\gamma=1}^{\nnc^{(h+1)}}\left(
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}
\frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left(\reluder{\left[\tilde{\vect{x}}^{(h+1)}_{(\gamma)}\right]_{k\ell}}\reluder{\left[\tilde{\vect{x}}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right)\nonumber\\
& \cdot \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}, i-k=i'-k',j-\ell=j'-\ell'\right\}\nonumber\\
\approx & \sum_{k,\ell,k',\ell'} \left(\frac{1}{\nnc^{(h+1)}}\sum_{\gamma=1}^{\nnc^{(h+1)}}
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}
\frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left(\frac{c_\sigma}{q^2\nnc^{(h+1)}}\sum_{\gamma=1}^{{\nnc^{(h+1)}}}\reluder{\left[\tilde{\vect{x}}^{(h+1)}_{(\gamma)}\right]_{k\ell}}\reluder{\left[\tilde{\vect{x}}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right)\nonumber\\
& \cdot \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}, i-k=i'-k',j-\ell=j'-\ell'\right\} \nonumber\\
\approx & \sum_{k,\ell,k',\ell'} \left(\frac{1}{\nnc^{(h+1)}}\sum_{\gamma=1}^{\nnc^{(h+1)}}
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}
\frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left[\dot{\mat{K}}^{(h+1)}\left(\vect{x},\vect{x}'\right)\right]_{k\ell,k'\ell'}\nonumber\\
& \cdot \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}, i-k=i'-k',j-\ell=j'-\ell'\right\}\nonumber\\
\approx & \tr\left( \left[ \mat{G}^{(h+1)}(\vect{x},\vect{x'},\vect{\theta}) \odot\dot{\mat{K}}^{(h+1)}\left(\vect{x},\vect{x}'\right)\right]_{D_{ij,i'j'}}\right)
\label{eqn:grad_fxh}
\end{align}
where the first approximation is due to our initialization of $`\mat{W}^{(h+1)}`$. In other words, we’ve shown
\begin{equation}
\label{eq:recursion_G} \mat{G}^{(h)}(\vect{x},\vect{x'},\vect{\theta}) = \linop\left(\mat{G}^{(h+1)}(\vect{x},\vect{x'},\vect{\theta})\odot \dot{\mat{K}}^{(h+1)}\left(\vect{x},\vect{x}'\right)\right).
\end{equation}
Since we use a fully-connected weight matrix as the last layer, we have $`\mat{G}^{(L)}(\vect{x},\vect{x'},\vect{\theta}) \approx \id`$.
Thus by induction with Equation [eq:recursion_G], we have derived Equation [eq:defi_G], which completes the derivation of CNTK.
For the derivation of CNTK-GAP, the only difference is due to the global average pooling (GAP) layer: $`\mat{G}^{(L)}(\vect{x},\vect{x'},\vect{\theta}) \approx \frac{1}{\nnh^2\nnw^2} \bm{1}\otimes\bm{1}`$, where $`\bm{1}\otimes\bm{1}\in \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh}`$ is the all-ones tensor.
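For reference, the backward recursion of Equation [eq:recursion_G], together with the two base cases just mentioned, can be sketched as follows (the function name and list layout are illustrative; the `linop` helper is the one sketched earlier):

```python
def backward_G(Kdot, q, c_sigma=2.0, gap=False):
    """Infinite-width sketch of G^(h) = L( G^(h+1) ⊙ dot-K^(h+1) ).

    Kdot: list of (P, Q, P, Q) tensors dot-K^(h) for h = 1..L (index 0 unused).
    Returns a dict mapping h -> G^(h), starting from G^(L) = Id for the vanilla
    CNN or the normalized all-ones tensor for the GAP variant.
    """
    L = len(Kdot) - 1
    P, Q = Kdot[L].shape[0], Kdot[L].shape[1]
    if gap:
        G_top = np.ones((P, Q, P, Q)) / (P ** 2 * Q ** 2)
    else:
        # Id_{ij,i'j'} = 1{i = i', j = j'}
        G_top = np.einsum('ik,jl->ijkl', np.eye(P), np.eye(Q))
    G = {L: G_top}
    for h in range(L - 1, -1, -1):
        G[h] = linop(G[h + 1] * Kdot[h + 1], q, c_sigma)
    return G
```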
In this section we define CNNs studied in this paper and derive the corresponding CNTKs.
Fully-connected NN
We first define the fully-connected neural network and its corresponding gradient.
\begin{align}
f(\params,\vect{x}) = \vect{a}^\top \relu{\mat{W}_H\relu{\mat{W}_{H-1}\ldots\relu{\mat{W}_1\vect{x}}}} \label{eqn:fc}
\end{align}
We can write the gradient in a compact form
\begin{align}
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}_{h}} = \back_{h+1}(x)^\top \left(\vect{x}^{(h-1)}\right)^\top
\end{align}
where
\begin{align}
\back_{h}(\vect{x}) = \vect{a}^\top \mat{D}_{H}(\vect{x})\mat{W}_{H}\cdots \mat{D}_h(\vect{x})\mat{W}_{h} \in \mathbb{R}^{1 \times m} \\
\mat{D}_h(\vect{x}) = \diag\left(\dot{\sigma}\left(\mat{W}_h\vect{x}^{(h-1)}\right)\right) \in \mathbb{R}^{m \times m}
\end{align}
and
\begin{align}
\vect{x}^{(h)} = \relu{\mat{W}_{h}\relu{\mat{W}_{h-1}\cdots\relu{\mat{W}_1\vect{x}}}}
\end{align}
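Below is a minimal NumPy sketch of this fully-connected network together with its per-layer gradients built from the backward vectors (the names are illustrative and the $`\sqrt{2/m}`$-type scaling factors are omitted for brevity):

```python
import numpy as np

def fc_forward_and_grads(x, Ws, a):
    """Forward pass of the fully-connected net above and the gradients dF/dW_h.

    x: (d,) input; Ws: list of weight matrices W_1, ..., W_H with iid Gaussian
    entries; a: output weight vector.  Scaling constants are omitted.
    """
    acts = [x]                                   # x^(0), x^(1), ..., x^(H)
    Ds = []                                      # diagonals of D_h = diag(relu'(W_h x^(h-1)))
    for W in Ws:
        pre = W @ acts[-1]
        Ds.append((pre > 0).astype(float))
        acts.append(np.maximum(pre, 0.0))
    f = a @ acts[-1]
    # Backward pass: start from the output weights and peel off one layer at a time.
    b = a
    grads = []
    for W, D, x_prev in zip(reversed(Ws), reversed(Ds), reversed(acts[:-1])):
        g = D * b                                # signal at this layer's pre-activation
        grads.append(np.outer(g, x_prev))        # dF/dW_h = g (x^(h-1))^T
        b = W.T @ g
    return f, list(reversed(grads))

# Finite-width kernel entry between two inputs (illustrative usage):
#   _, gx = fc_forward_and_grads(x1, Ws, a); _, gy = fc_forward_and_grads(x2, Ws, a)
#   ker = sum(np.sum(gh * gh2) for gh, gh2 in zip(gx, gy))
```

The finite-width kernel entry $`\left\langle \frac{\partial f(\params,\vect{x})}{\partial \params}, \frac{\partial f(\params,\vect{x}')}{\partial \params}\right\rangle`$ is then the sum of the per-layer gradient inner products, which is the quantity the NTK describes in the infinite-width limit.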
CNN
We consider the following two-dimensional convolutional neural network.
-
Input $`\vect{x}^{(0)} \in \mathbb{R}^{\nnw\times \nnh \times \nnc^{(0)}}`$, where $`\nnw`$ is the width, $`\nnh`$ is the height, and $`\nnc^{(0)}`$ is the number of input channels.
-
For $`\ell=1,\ldots,L`$, $`\beta = 1,\ldots,\nnc^{(\ell)}`$, the intermediate outputs are defined as
\begin{align}
\tilde{\vect{x}}_{(\beta)}^{(\ell)} = \sum_{\alpha=1}^{\nnc^{(\ell-1)}} \mat{W}_{(\alpha),(\beta)}^{(\ell)} \conv \vect{x}_{(\alpha)}^{(\ell-1)} \in \mathbb{R}^{\nnw \times \nnh}
\end{align}
where $`\mat{W}_{(\alpha),(\beta)}^{(\ell)} \sim \gauss(\vect{0},\mat{I}) \in \mathbb{R}^{q^{(\ell)} \times q^{(\ell)}}`$ and
\begin{align}
\vect{x}^{(\ell)}_{(\beta)} = \sqrt{\frac{c_{\sigma}}{\nnc^{(\ell)} \times q^{(\ell)} \times q^{(\ell)}}}\act{\tilde{\vect{x}}_{(\beta)}^{(\ell)}}.
\end{align}
-
The final output is defined as
\begin{align}
f(\vect{x}^{(0)}) = \sum_{\beta=1}^{\nnc^{(L)}} \left\langle \mat{W}_{(\beta)}^{(L)},\vect{x}_{(\beta)}^{(L)}\right\rangle.
\end{align}
The convolution operator is defined as follows: for $`\mat{W} \in \mathbb{R}^{q\times q}`$ and $`\vect{x} \in \mathbb{R}^{\nnw \times \nnh}`$,
\begin{align}
[\mat{W}\conv \vect{x}]_{ij} = \sum_{(a,b) = (-\frac{q-1}{2}, -\frac{q-1}{2})}^{(\frac{q-1}{2}, \frac{q-1}{2})}\vect{x}_{i+a,j+b} \cdot \mat{W}_{a+\frac{q+1}{2},b+\frac{q+1}{2}}.
\end{align}
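A direct NumPy transcription of this operator (zero padding, filter centered at each output pixel; the helper name is illustrative):

```python
import numpy as np

def conv_single_channel(W, x):
    """[W * x]_{ij} = sum_{a,b in [-(q-1)/2, (q-1)/2]} x_{i+a, j+b} W_{a+r, b+r} (0-indexed)."""
    q = W.shape[0]
    r = (q - 1) // 2
    P, Q = x.shape
    xp = np.pad(x, ((r, r), (r, r)))   # zero padding
    out = np.zeros((P, Q))
    for i in range(P):
        for j in range(Q):
            out[i, j] = np.sum(xp[i:i + q, j:j + q] * W)
    return out
```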
In this subsection we prove the following lemma.
Fix $`\omega \le\poly(1/L,1/n,1/\log(1/\delta), \lambda_0)`$. Suppose we set $`m \ge \poly(1/\omega)`$ and $`\kappa \le 1`$. Then with probability at least $`1-\delta`$ over random initialization, we have for all $`t \ge 0`$, for any $`(\vect{x},\vect{x}') \in \left\{\vect{x}_1,\ldots,\vect{x}_n,\vect{x}_{te}\right\} \times \left\{\vect{x}_1,\ldots,\vect{x}_n,\vect{x}_{te}\right\}`$
\begin{equation*}
\abs{
\kernel_{t}\left(\vect{x},\vect{x}'\right)- \kernel_{0}\left(\vect{x},\vect{x}'\right)
} \le \omega
\end{equation*}
Recall for any fixed $`\vect{x}`$ and $`\vect{x}'`$, Theorem [thm:ntk_init] shows $`\abs{\kernel_{0}(\vect{x},\vect{x}')-\kernel_{ntk}(\vect{x},\vect{x}')} \le \epsilon`$ if $`m`$ is large enough. The next lemma shows we can reduce the problem of bounding the perturbation on the kernel value to the perturbation on the gradient.
If $`\norm{\frac{\partial f(\params(t),\vect{x})}{\partial \params} - \frac{\partial f(\params(0),\vect{x})}{\partial \params}} \le \epsilon`$ and $`\norm{\frac{\partial f(\params(t),\vect{x}')}{\partial \params} - \frac{\partial f(\params(0),\vect{x}')}{\partial \params}} \le \epsilon`$, we have
\begin{align*}
\abs{\kernel_{t}(\vect{x},\vect{x}') - \kernel_{0}(\vect{x},\vect{x}')} \le O\left(\epsilon\right)
\end{align*}
Proof. By the proof of Theorem [thm:ntk_init], we know $`\norm{\frac{\partial f(\params(0),\vect{x})}{\partial \params}}_2 = O\left(1\right)`$. Then we can just use triangle inequality. ◻
Now we proceed to analyze the perturbation on the gradient. Note we can focus on the perturbation on a single sample $`\vect{x}`$ because we can later take a union bound. Therefore, in the rest of this section, we drop the dependency on a specific sample. We use the following notations in this section. Recall $`\mat{W}^{(1)},\ldots,\mat{W}^{(L+1)} \sim \gauss\left(\mat{0},\mat{I}\right)`$ and we denote $`\diff \mat{W}^{(1)},\ldots,\diff \mat{W}^{(L+1)}`$ the perturbation matrices. We let $`\widetilde{\mat{W}}^{(h)} = \mat{W}^{(h)} + \diff\mat{W}^{(h)}`$. We let $`\tilde{\vect{g}}^{(0)} = \vect{g}^{(0)}=\vect{x}`$ and for $`h=1,\ldots,L`$ we define
\begin{align*}
\vect{z}^{(h)} = &\sqrt{\frac{2}{m}}\mat{W}^{(h)}
\vect{g}^{(h-1)}, ~~~ \vect{g}^{(h)} = \relu{\vect{z}^{(h)}},\\
\tilde{\vect{z}}^{(h)} = &\sqrt{\frac{2}{m}}\widetilde{\mat{W}}^{(h)}
\tilde{\vect{g}}^{(h-1)}, ~~~ \tilde{\vect{g}}^{(h)} = \relu{\tilde{\vect{z}}^{(h)}}.
\end{align*}
For $`h=1,\ldots,L`$, $`i=1,\ldots,m`$, we denote
\begin{align*}
[\mat{D}^{(h)}]_{ii} =
&\indict\left\{\left[\mat{W}^{(h)}\right]_{i,:}\vect{g}^{(h-1)}\ge 0\right\}\\
[\widetilde{\mat{D}}^{(h)}]_{ii} =
&\indict\left\{\left[\widetilde{\mat{W}}^{(h)}\right]_{i,:}\tilde{\vect{g}}^{(h-1)}\ge 0\right\}.
\end{align*}
Note $`\vect{z}^{(h)} = \sqrt{\frac{2}{m}}\vect{f}^{(h)}`$. Here we use $`\vect{z}^{(h)}`$ instead of $`\vect{f}^{(h)}`$ for the ease of presentation.
For convenience, we also define
\begin{align*}
\diff \mat{D}^{(h)} = \widetilde{\mat{D}}^{(h)} - \mat{D}^{(h)}.
\end{align*}
Recall the gradient to $`\mat{W}^{(h)}`$ is:
\begin{align*}
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}^{(h)}} = \vect{b}^{(h)} \left(\vect{g}^{(h-1)}\right)^\top
\end{align*}
Similarly, we have
\begin{align*}
\frac{\partial f(\params,\vect{x})}{\partial \widetilde{\mat{W}}^{(h)}} = \tilde{\vect{b}}^{(h)} \left(\tilde{\vect{g}}^{(h-1)}\right)^\top
\end{align*}
where
\begin{align*}
\tilde{\vect{b}}^{(h)} = \begin{cases}
1 &\text{ if } h = L+1\\
\sqrt{\frac{2}{m}}\widetilde{\mat{D}}^{(h)}\left(\widetilde{\mat{W}}^{(h+1)}\right)^\top\tilde{\vect{b}}^{(h+1)} &\text{ Otherwise}
\end{cases} .
\end{align*}
This gradient formula allows us to bound the perturbation on $`\diff \vect{g}^{(h)}\triangleq\tilde{\vect{g}}^{(h)} - \vect{g}^{(h)}`$ and $`\diff \vect{b}^{(h)}\triangleq\tilde{\vect{b}}^{(h)}-\vect{b}^{(h)}`$ separately. The following lemmas, adapted from prior work, show that with high probability over the initialization, bounding the perturbations $`\diff \vect{g}^{(h)}`$ and $`\diff \vect{b}^{(h)}`$ reduces to bounding the perturbations of the weight matrices.
Suppose
\omega\le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta)\right).
Then with probability at least $`1-\delta`$ over random initialization, if $`\norm{\diff \mat{W}^{(h)}}_2 \le \sqrt{m}\omega`$ for all $`h=1,\ldots,L`$, we have $`\norm{\diff \vect{g}^{(h)}}_2 = O(\omega L^{5/2}\sqrt{\log m})`$ for all $`h=1,\ldots,L`$.
While the original proof did not consider a perturbation on $`\mat{W}^{(1)}`$, by scrutinizing it one can see that the perturbation bound still holds even if there is a small perturbation on $`\mat{W}^{(1)}`$.
The next lemma, also adapted from prior work, bounds the perturbation of the backward vector.
Suppose
\omega\le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta)\right).
Then with probability at least $`1-\delta`$ over random initialization, if $`\norm{\diff \mat{W}^{(h)}}_2 \le \sqrt{m}\omega`$ for all $`h=1,\ldots,L+1`$, we have for all $`h=1,\ldots,L+1`$, $`\norm{\tilde{\vect{b}}^{(h)}-\vect{b}^{(h)}}_2 = O\left(\omega^{1/3}L^2\sqrt{\log m}\right)`$.
While the original proof did not consider a perturbation on $`\mat{W}^{(L+1)}`$, by scrutinizing it one can see that the perturbation bound still holds even if there is a small perturbation on $`\mat{W}^{(L+1)}`$.
Combining these two lemmas and the result for the initialization (Theorem [thm:ntk_init]), we have the following “gradient-Lipschitz” lemma.
Suppose $`\omega \le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta)\right).`$ Then with probability at least $`1-\delta`$ over random initialization, if $`\norm{\diff \mat{W}^{(h)}}_2 \le \sqrt{m}\omega`$ for all $`h=1,\ldots,L+1`$, we have for all $`h=1,\ldots,L+1`$:
\begin{align*}
\norm{\tilde{\vect{b}}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\vect{g}^{(h-1)}\right)^\top}_F = O\left(\omega^{1/3}L^{5/2}\sqrt{\log m}\right)
\end{align*}
Proof. We use the triangle inequality to bound the perturbation
\begin{align*}
&\norm{\tilde{\vect{b}}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\vect{g}^{(h-1)}\right)^\top}_F \\
\le& \norm{\tilde{\vect{b}}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top}_F + \norm{\vect{b}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\vect{g}^{(h-1)}\right)^\top}_F \\
\le & \norm{
\diff \vect{b}^{(h)}\left(\vect{g}^{(h-1)}+\diff \vect{g}^{(h-1)}\right)^\top
}_F + \norm{\vect{b}^{(h)}\left(\diff \vect{g}^{(h-1)}\right)^\top}_F\\
= & O\left(\omega^{1/3}L^{5/2}\sqrt{\log m}\right).
\end{align*}
◻
The following lemma shows that for a given weight matrix, if we have linear convergence and the other weight matrices are only perturbed a little, then the given matrix is only perturbed a little as well.
Fix $`h \in [L+1]`$ and a sufficiently small $`\omega \le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta),\kappa\right).`$ Suppose for all $`t \ge 0`$, $`\norm{\vect{u}_{nn}(t)-\vect{y}}_2 \le \exp\left(-\frac12\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2`$ and $`\norm{\mat{W}^{(h')}(t)-\mat{W}^{(h')}(0)}_F \le \omega \sqrt{m}`$ for $`h'\neq h`$. Then if $`m \ge \poly\left(1/\omega\right)`$ we have with probability at least $`1-\delta`$ over random initialization, for all $`t \ge 0`$
\begin{align*}
\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F = O \left( \frac{\sqrt{n}}{\lambda_0}\right)\le\omega\sqrt{m}.
\end{align*}
Proof. We let $`C,C_0, C_1, C_2, C_3 > 0`$ be some absolute constants.
\begin{align*}
&\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F\\
= &\norm{\int_{0}^{t}\frac{d \mat{W}^{(h)}(\tau)}{d\tau} d\tau}_F \\
=& \norm{\int_{0}^{t}\frac{\partial L(\params(\tau))}{\partial \mat{W}^{(h)}(\tau)} d\tau}_F \\
= & \norm{\int_{0}^{t}\frac{1}{n}\sum_{i=1}^n \left(u_i(\tau)-y_i\right) \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}} d\tau}_F \\
\le & \frac{1}{n}\max_{0\le \tau\le t}\sum_{i=1}^n\norm{ \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\int_{0}^{t} \norm{\vect{u}_{nn}(\tau)-\vect{y}}_2 d\tau\\
\le & \frac{1}{n}\max_{0\le \tau\le t}\sum_{i=1}^n\norm{ \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\int_{0}^{t} \exp\left(-\kappa^2\lambda_0\tau\right) d\tau\norm{\vect{u}_{nn}(0)-\vect{y}}_2 \\
\le & \frac{C_0}{\sqrt{n}\lambda_0}\max_{0\le \tau\le t}\sum_{i=1}^n\norm{ \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F &\\
\le &\frac{C_0}{\sqrt{n}\kappa^2\lambda_0}\max_{0\le \tau\le t}\sum_{i=1}^n\left(\norm{ \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F + \norm{\frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\right)\\
\le &\frac{C_1}{\sqrt{n}\lambda_0}\max_{0\le \tau\le t}\sum_{i=1}^n\left(\norm{ \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F + \norm{\frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\right) \\
\le &\frac{C_2\sqrt{n}}{\lambda_0}+\frac{C_1\sqrt{n}}{\lambda_0}\max_{0\le \tau\le t}\left( \norm{\frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\right).
\end{align*}
In the last step we used $`\norm{ \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F = O(1)`$. Suppose there exists $`t`$ such that $`\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}.`$ Denote
t_0 = \argmin_{t\ge 0}\left\{\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}\right\}.
For any $`t < t_0`$, we know for all $`h' \in [L+1]`$, $`\norm{\mat{W}^{(h')}(t)-\mat{W}^{(h')}(0)}_2 \le \omega\sqrt{m}.`$ Therefore, by Lemma [lem:grad_lip], we know
\begin{align*}
\norm{\frac{\partial f_{nn}(\params(t),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F \le C \omega^{1/3}L^{5/2}\sqrt{\log m}.
\end{align*}
Therefore, using the fact that $`\omega`$ is sufficiently small we can bound
\begin{align*}
\norm{\mat{W}^{(h)}(t_0)-\mat{W}^{(h)}(0)}_F \le \frac{C_3\sqrt{n}}{\lambda_0}.
\end{align*}
Since we also know $`m`$ is sufficiently large to make $`\omega\sqrt{m} > \frac{C_3\sqrt{n}}{\lambda_0}`$, we have a contradiction. ◻
The next lemma shows if all weight matrices only have small perturbation, then we still have linear convergence.
Suppose $`\omega \le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta),\kappa\right).`$ Suppose for all $`t \ge 0`$, $`\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F \le \omega \sqrt{m}`$ for $`h \in [L+1]`$. Then if $`m \ge \poly\left(1/\omega\right)`$, we have with probability at least $`1-\delta`$ over random initialization, for all $`t \ge 0`$
\begin{align*}
\norm{\vect{u}_{nn}(t)-\vect{y}}_2 \le \exp\left(-\frac{1}{2}\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2.
\end{align*}
Proof. Under this assumption and the result at initialization, we know that for all $`t \ge 0`$, $`\lambda_{\min}\left(\trainker(t)\right) \ge \frac{1}{2}\lambda_0`$. This in turn directly implies the linear convergence result we want. ◻
Lastly, with these lemmas at hand, using an argument similar to prior work, we can show that during training the weight matrices do not move by much.
Let $`\omega \le \poly(\eps,L,\lambda_0,1/\log(m),1/\log(1/\delta),\kappa, 1/n)`$. If $`m \ge \poly(1/\omega)`$, then with probability at least $`1-\delta`$ over random initialization, we have for all $`t \ge 0`$, for all $`h \in [L+1]`$ we have
\begin{align*}
\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F \le \omega\sqrt{m}
\end{align*}
and
\begin{align*}
\norm{\vect{u}_{nn}(t)-\vect{y}}_2 \le \exp\left(-\frac{1}{2}\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2.
\end{align*}
Proof. Let
\begin{align*}
t_0 = \argmin_{t}\left\{\exists h \in [L+1], \norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m} \right.\\
\left.\text{ or }\norm{\vect{u}_{nn}(t)-\vect{y}}_2 > \exp\left(-\frac{1}{2}\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2\right\}.
\end{align*}
We analyze the two cases separately. Suppose at time $`t_0`$, $`\norm{\mat{W}^{(h)}(t_0)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}`$. By Lemma [lem:lin_conv_w_perb_small], we know there exists some $`0\le t_1 < t_0`$ such that either there exists $`h' \neq h`$ such that
\norm{\mat{W}^{(h')}(t_1)-\mat{W}^{(h')}(0)}_F > \omega\sqrt{m}
or
\norm{\vect{u}_{nn}(t_1)-\vect{y}}_2 > \exp\left(-\frac{1}{2}\kappa^2\lambda_0t_1\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2.
However, this violates the minimality of $`t_0`$. For the other case, if
\norm{\vect{u}_{nn}(t_0)-\vect{y}}_2 > \exp\left(-\frac{1}{2}\kappa^2\lambda_0t_0\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2,
By Lemma [lem:w_perb_small_lin_conv], we know there exists $`t_1 < t_0`$ such that there exists $`h \in [L+1]`$,
\norm{\mat{W}^{(h)}(t_1)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}.
However, again this violates the minimality of $`t_0`$. ◻
Now we can finish the proof of Lemma [lem:ker_perb_train].
Proof of Lemma [lem:ker_perb_train]. By Lemma [lem:cont_induction], we know that for all $`t \ge 0`$, $`\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F \le O\left(\omega\sqrt{m}\right)`$ if $`\omega`$ is sufficiently small. Applying Lemma [lem:grad_lip], we know the gradient is only perturbed by a little. Applying Lemma [lem:grad_to_kernel], we know the kernel values are only perturbed by a little. ◻
From a Gaussian process (GP) viewpoint, the correspondence between infinite neural networks and kernel machines was first noted in early work on Bayesian neural networks. Follow-up work extended this correspondence to more general shallow neural networks, and it was more recently extended to deep and convolutional neural networks and a variety of other architectures. However, these kernels, as we discussed in Section 1, represent weakly-trained nets, instead of fully-trained nets.
Beyond GPs, the connection between neural networks and kernels has also been studied in the compositional kernel literature. One line of work derived a closed-form kernel formula for rectified polynomial activations, which include ReLU as a special case. Another proposed a general framework to transform a neural network into a compositional kernel and later showed that, for sufficiently wide neural networks, stochastic gradient descent can learn functions that lie in the corresponding reproducing kernel Hilbert space. However, the kernels studied in these works still correspond to weakly-trained neural networks.
This paper is inspired by a line of recent work on over-parameterized neural networks. These papers established that for (convolutional) neural networks with large but finite width, (stochastic) gradient descent can achieve zero training error. A key component in these papers is showing that the weight matrix at each layer stays close to its initialization. This observation implies that the kernel defined in Equation [eqn:ntk] also stays close to its initialization, an observation that has been used explicitly to derive generalization bounds for two-layer over-parameterized neural networks. It has been argued that results in this kernel regime may be too simple to explain the success of deep learning; on the other hand, our results show that the CNTK is at least able to perform well on tasks like CIFAR-10 classification. Also see recent surveys for advances in deep learning theory.
The authors of [Jacot et al., 2018] derived the exact same kernel from kernel gradient descent. They showed that if the number of neurons per layer goes to infinity in a sequential order, then the kernel remains unchanged for a finite training time. They termed the derived kernel the Neural Tangent Kernel (NTK). We follow the same naming convention and name its convolutional extension the Convolutional Neural Tangent Kernel (CNTK). Later work derived a formula for the CNTK as well as a mechanistic way to derive the NTK for different architectures. Compared with that work, our CNTK formula has a more explicit convolutional structure and results in an efficient GPU-friendly computation method. Recent work has also tried to empirically verify this theory by studying the linearization of neural nets, observing that in the first few iterations the linearization is close to the actual neural net. However, as will be shown in Section 8, such linearization can decrease the classification accuracy by $`5\%`$ even on a “CIFAR-2” (airplane vs. car) dataset. Therefore, exact kernel evaluation is important for studying the power of the NTK and CNTK.
In this section we define the CNN with global average pooling considered in this paper and its corresponding CNTK formula.
CNN definition.
Let $`\vect{x} = \vect{x}^{(0)} \in \mathbb{R}^{\nnw\times \nnh \times \nnc^{(0)}}`$ be the input image and $`\nnc^{(0)}`$ is the number of initial channels.
For $`h=1,\ldots,L`$, $`\beta = 1,\ldots,\nnc^{(h)}`$, the intermediate outputs are defined as
\begin{align*}
\tilde{\vect{x}}_{(\beta)}^{(h)} = \sum_{\alpha=1}^{\nnc^{(h-1)}} \mat{W}_{(\alpha),(\beta)}^{(h)} \conv \vect{x}_{(\alpha)}^{(h-1)} ,\quad
\vect{x}^{(h)}_{(\beta)} = \sqrt{\frac{c_{\sigma}}{\nnc^{(h)} \times q^{(h)} \times q^{(h)}}}\act{\tilde{\vect{x}}_{(\beta)}^{(h)}}.
\end{align*}
The final output is defined as
\begin{align*}
f(\params,\vect{x}) =\sum_{\alpha=1}^{\nnc^{(L)}} W_{(\alpha)}^{(L + 1)}\left( \frac{1}{\nnw \nnh} \sum_{(i,j) \in [\nnw] \times [\nnh]}\left[\vect{x}_{(\alpha)}^{(L)}\right]_{ij}\right).
\end{align*}
where $`W_{(\alpha)}^{(L + 1)} \in \mathbb{R}`$ is a scalar with Gaussian initialization.
Besides using global average pooling, another modification is that we do not train the first and the last layers. This is inspired by prior work in which the authors showed that if one applies gradient flow, then at any training time $`t`$, the difference between the squared Frobenius norm of the weight matrix at time $`t`$ and that at initialization is the same for all layers. However, note that $`\mat{W}^{(1)}`$ and $`\mat{W}^{(L+1)}`$ are special because they are smaller matrices compared with the intermediate weight matrices, so relatively, these two weight matrices change more than the intermediate matrices during the training process, and this may dramatically change the kernel. Therefore, we choose to fix $`\mat{W}^{(1)}`$ and $`\mat{W}^{(L+1)}`$ to make the over-parameterization theory closer to practice.
CNTK formula.
We let $`\vect{x},\vect{x}'`$ be two input images. Note that because the CNN with global average pooling and the vanilla CNN share the same architecture except for the last layer, $`\mat{\Sigma}^{(h)}(\vect{x},\vect{x}')`$, $`\dot{\mat{\Sigma}}^{(h)}(\vect{x},\vect{x}')`$ and $`\mat{K}^{(h)}(\vect{x},\vect{x}')`$ are the same for these two architectures; the only difference is in calculating the final kernel value. To compute the final kernel value, we use the following procedure.
First, we define $`\mat{\Theta}^{(0)}(\vect{x},\vect{x}') = \vect{0}`$. Note that this is different from the CNTK for the vanilla CNN, which uses $`\mat{\Sigma}^{(0)}`$ as the initial value; here we do not train the first layer.
For $`h=1,\ldots,L - 1`$ and $`(i,j,i',j') \in [\nnw] \times [\nnh] \times [\nnw] \times [\nnh]`$, we define
\begin{align*}
\left[\mat{\Theta}^{(h)}(\vect{x},\vect{x}')\right]_{ij,i'j'} = \tr\left(\left[\dot{\mat{K}}^{(h)}(\vect{x},\vect{x}')\odot\mat{\Theta}^{(h-1)}(\vect{x},\vect{x}')+\mat{K}^{(h)}(\vect{x},\vect{x}')\right]_{D_{ij,i'j'}}\right).
% \label{eqn:vanila_cnn_gradient_kernel}
\end{align*}
For $`h=L`$, we define $`\mat{\Theta}^{(L)}(\vect{x},\vect{x}') = \dot{\mat{K}}^{(L)}(\vect{x},\vect{x}')\odot\mat{\Theta}^{(L-1)}(\vect{x},\vect{x}')`$.
Lastly, the final kernel value is defined as
\frac{1}{\nnw^2 \nnh^2} \sum_{(i,j,i',j') \in [\nnw] \times [\nnh] \times [\nnw] \times [\nnh]} \left[\mat{\Theta}^{(L)}(\vect{x},\vect{x}')\right]_{ij,i'j'} .
Note that we ignore $`\mat{K}^{(L)}`$ compared with the CNTK of the vanilla CNN, because we do not train the last layer. The other difference is that we calculate the mean over all entries instead of the summation over the diagonal ones, because with global average pooling the cross-covariances between every pair of patches contribute to the kernel.
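Putting these steps together, here is a structural sketch of the CNTK-GAP computation for one input pair (it assumes the per-layer tensors $`\mat{K}^{(h)}`$ and $`\dot{\mat{K}}^{(h)}`$ are already computed, reuses the `linop` sketch from earlier, and treats constant factors inside the patch-wise trace as illustrative rather than paper-exact):

```python
def cntk_gap(K, Kdot, q, c_sigma=2.0):
    """CNTK with global average pooling; first and last layers are not trained.

    K, Kdot: lists of (P, Q, P, Q) tensors K^(h) and dot-K^(h) for h = 1..L
             (index 0 unused), assumed precomputed for a pair (x, x').
    """
    L = len(K) - 1
    P, Q = K[L].shape[0], K[L].shape[1]
    theta = np.zeros((P, Q, P, Q))     # Theta^(0) = 0: the first layer is not trained
    for h in range(1, L):
        theta = linop(Kdot[h] * theta + K[h], q, c_sigma)
    theta = Kdot[L] * theta            # h = L: no K^(L) term, the last layer is not trained
    return float(theta.mean())         # GAP: average over all P^2 Q^2 entries
```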
📊 Paper Figures