Exact Computation with Infinitely Wide Neural Networks

Reading time: 23 minutes
...

📝 Original Paper Info

- Title: On Exact Computation with an Infinitely Wide Neural Net
- ArXiv ID: 1904.11955
- Date: 2019-11-05
- Authors: Sanjeev Arora, Simon S. Du, Wei Hu, Zhiyuan Li, Ruslan Salakhutdinov, Ruosong Wang

📝 Abstract

How well does a classic deep net architecture like AlexNet or VGG19 classify on a standard dataset such as CIFAR-10 when its width --- namely, number of channels in convolutional layers, and number of nodes in fully-connected internal layers --- is allowed to increase to infinity? Such questions have come to the forefront in the quest to theoretically understand deep learning and its mysteries about optimization and generalization. They also connect deep learning to notions such as Gaussian processes and kernels. A recent paper [Jacot et al., 2018] introduced the Neural Tangent Kernel (NTK) which captures the behavior of fully-connected deep nets in the infinite width limit trained by gradient descent; this object was implicit in some other recent papers. An attraction of such ideas is that a pure kernel-based method is used to capture the power of a fully-trained deep net of infinite width. The current paper gives the first efficient exact algorithm for computing the extension of NTK to convolutional neural nets, which we call Convolutional NTK (CNTK), as well as an efficient GPU implementation of this algorithm. This results in a significant new benchmark for the performance of a pure kernel-based method on CIFAR-10, being $10\%$ higher than the methods reported in [Novak et al., 2019], and only $6\%$ lower than the performance of the corresponding finite deep net architecture (once batch normalization, etc. are turned off). Theoretically, we also give the first non-asymptotic proof showing that a fully-trained sufficiently wide net is indeed equivalent to the kernel regression predictor using NTK.

💡 Summary & Analysis

This paper studies infinitely wide neural networks through the Neural Tangent Kernel (NTK) and its convolutional extension, the CNTK. A key ingredient of the analysis is showing that, for sufficiently wide networks, the weight matrices do not significantly deviate from their random initialization during training, yet the network still achieves rapid convergence. This near-constancy of the weights is what makes a fully-trained, sufficiently wide net equivalent to the kernel regression predictor with the (C)NTK, and it underscores the importance of proper initialization.

Key Summary

The paper demonstrates that during training, the weight matrices of a sufficiently wide neural network do not significantly deviate from their initial values, while the network still converges rapidly and performs well.

Problem Statement

It is not well understood why heavily over-parameterized neural networks trained by gradient descent converge quickly and generalize well. Large changes in the weights would break the connection to the kernel at initialization, so characterizing how much the weights actually move during training is a key step toward explaining this behavior.

Solution (Core Technology)

The paper gives a non-asymptotic analysis showing that, when the network is wide enough, the weight matrices do not deviate significantly from their initial values during training, and that despite these minor changes the network still achieves rapid (linear) convergence. Building on this, it derives the Convolutional Neural Tangent Kernel (CNTK) together with an exact, efficient algorithm for computing it.

Major Achievements

The study shows that sufficiently wide networks maintain high performance and rapid convergence without significant deviations from their initialized weights. As reported in the abstract, the resulting CNTK gives a strong pure kernel-based benchmark on CIFAR-10: about $10\%$ higher accuracy than previously reported kernel methods and only about $6\%$ below the corresponding finite deep net.

Significance and Applications

This research makes the correspondence between fully-trained, infinitely wide networks and kernel regression precise and computable. It highlights the role of proper initialization and the stability of the weights during training, and it provides a practical, GPU-friendly kernel (the CNTK) for studying optimization and generalization in deep learning.

📄 Full Paper Content (ArXiv Source)

By giving the first practical algorithm for computing CNTKs exactly, this paper allows investigation of the behavior of infinitely wide (hence infinitely over-parametrized) deep nets, which turns out to not be much worse than that of their finite counterparts. We also give a fully rigorous proof that a sufficiently wide net is approximately equivalent to the kernel regression predictor, thus yielding a powerful new off-the-shelf kernel. We leave it as an open problem to understand the behavior of infinitely wide nets with features such as Batch Normalization or Residual Layers. Of course, one can also hope that the analysis of infinite nets provides rigorous insight into finite ones.

In this section we derive the CNTK for a vanilla CNN. Given $`\vect{x} \in \mathbb{R}^{\nnw \times \nnh}`$ and $`(i,j) \in [\nnw]\times [\nnh]`$, we define

MATH
\phi_{ij}(\vect{x}) = [\vect{x}]_{i-(q-1)/2:i+(q-1)/2,j-(q-1)/2:j+(q-1)/2}

i.e., this operator extracts the $`q \times q`$ patch centered at $`(i,j)`$. Using this definition, we can rewrite the CNN definition:

Let $`\vect{x}^{(0)} =\vect{x} \in \mathbb{R}^{\nnw\times \nnh \times \nnc^{(0)}}`$ be the input image where $`\nnc^{(0)}`$ is the number of channels in the input image.

For $`h=1,\ldots,L`$, $`\beta = 1,\ldots,\nnc^{(h)}`$, the intermediate outputs are defined as

MATH
\begin{align*}
    \left[\tilde{\vect{x}}_{(\beta)}^{(h)}\right]_{ij} = 
    \sum_{\alpha=1}^{\nnc^{(h-1)}}\left\langle\mat{W}_{(\alpha),(\beta)}^{(h)},\phi_{ij}\left(\vect{x}_{(\alpha)}^{(h-1)}\right)\right\rangle, \quad
    \vect{x}^{(h)}_{(\beta)} = \sqrt{\frac{c_{\sigma}}{\nnc^{(h)} \times q \times q}}\act{\tilde{\vect{x}}_{(\beta)}^{(h)}}
\end{align*}

where each $`\mat{W}_{(\alpha),(\beta)}^{(h)} \in \mathbb{R}^{q \times q}`$ is a filter with Gaussian initialization.

The final output is defined as

MATH
\begin{align*}
    f(\params,\vect{x}) = \sum_{\alpha=1}^{\nnc^{(L)}} \left\langle \mat{W}_{(\alpha)}^{(L)},\vect{x}_{(\alpha)}^{(L)}\right\rangle
\end{align*}

where $`\mat{W}_{(\alpha)}^{(L)} \in \mathbb{R}^{\nnw \times \nnh}`$ is a weight matrix with Gaussian initialization.
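
For concreteness, here is a minimal NumPy sketch of this patch-based forward pass (our own illustration, not code from the paper; the helper names, the placeholder value $`c_\sigma = 2`$, and the tiny test sizes are ours, and the triple loop is written for clarity rather than speed).

```python
import numpy as np

def patch(x, i, j, q):
    """phi_{ij}(x): the q x q patch of one channel x (shape P x Q) centered at (i, j), zero padded."""
    r = (q - 1) // 2
    return np.pad(x, r)[i:i + q, j:j + q]        # (i, j) are 0-based here

def cnn_forward(x, filters, W_out, q, c_sigma=2.0):
    """Vanilla CNN of the form above.

    x: (P, Q, C0) input image; filters[h]: (C_in, C_out, q, q) Gaussian filters;
    W_out: (P, Q, C_L) final weights.  Returns the scalar output f(theta, x).
    """
    act = x
    for W in filters:
        C_in, C_out = W.shape[0], W.shape[1]
        P, Q = act.shape[0], act.shape[1]
        pre = np.zeros((P, Q, C_out))
        for beta in range(C_out):
            for i in range(P):
                for j in range(Q):
                    # [x~^{(h)}_{(beta)}]_{ij} = sum_alpha <W^{(h)}_{(alpha),(beta)}, phi_{ij}(x^{(h-1)}_{(alpha)})>
                    pre[i, j, beta] = sum(np.sum(W[alpha, beta] * patch(act[:, :, alpha], i, j, q))
                                          for alpha in range(C_in))
        # x^{(h)}_{(beta)} = sqrt(c_sigma / (c^{(h)} q^2)) * relu(x~^{(h)}_{(beta)})
        act = np.sqrt(c_sigma / (C_out * q * q)) * np.maximum(pre, 0.0)
    return float(np.sum(W_out * act))            # f = sum_alpha <W^{(L)}_{(alpha)}, x^{(L)}_{(alpha)}>

# Tiny random instance with Gaussian-initialized filters (sizes are arbitrary).
rng = np.random.default_rng(0)
P, Q, q, widths = 8, 8, 3, [3, 16, 16]
filters = [rng.normal(size=(widths[h], widths[h + 1], q, q)) for h in range(len(widths) - 1)]
W_out = rng.normal(size=(P, Q, widths[-1]))
x = rng.normal(size=(P, Q, widths[0]))
print(cnn_forward(x, filters, W_out, q))
```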

Expansion of CNTK

We expand $`\Theta^{(L)}(\vect{x},\vect{x}')`$ to show we can write it as the sum of $`(L+1)`$ terms, each representing the inner product between the gradients with respect to the weights of one layer. We first define a linear operator

MATH
\begin{align}
    \linop: \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh} \rightarrow \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh} \nonumber\\
    \left[\linop\left(\mat{M}\right)\right]_{k\ell,k'\ell'} = \frac{c_{\sigma}}{q^2}\tr\left(\left[\mat{M}\right]_{\indset_{k\ell,k'\ell'}}\right). \label{eqn:linop}
\end{align}

This linear operator is induced by the convolution operation. Here we use zero padding: when a subscript falls outside the range $`[\nnw]\times [\nnh]\times[\nnw]\times [\nnh]`$, the corresponding entry is taken to be zero.

We also define $`\id\in \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh}`$ as the identity tensor, namely $`\id_{i,j,i',j'}= \bm{1}\{i=i',j=j'\}`$, and

MATH
\Sum{\mat{M}}=\sum_{(i,j,i',j') \in [\nnw] \times [\nnh]\times [\nnw]\times [\nnh] } \mat{M}_{i,j,i',j'}.

The following property of $`\linop`$ is immediate by definition: $`\forall \mat{M},\mat{N} \in \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh}`$, we have

MATH
\begin{equation}
\label{eq:conjugate2}
\Sum{\mat{M}\odot \linop(\mat{N})}  = \Sum{\linop(\mat{M})\odot \mat{N}}.
\end{equation}
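
To make the operator concrete, here is a minimal NumPy sketch (our own, not the paper's implementation) that applies $`\linop`$ by summing matched patch offsets with zero padding and numerically checks the property above; the helper name `trace_op` and the test sizes are ours, and the value of $`c_\sigma`$ is irrelevant for the check.

```python
import numpy as np

def trace_op(M, q, c_sigma=2.0):
    """Apply L: [L(M)]_{kl,k'l'} = (c_sigma/q^2) * sum_{a,b} M[k+a, l+b, k'+a, l'+b].

    M has shape (P, Q, P, Q); out-of-range indices contribute zero (zero padding).
    """
    P, Q = M.shape[0], M.shape[1]
    r = (q - 1) // 2
    Mp = np.pad(M, r)                          # zero-pad all four spatial axes
    out = np.zeros_like(M)
    for a in range(-r, r + 1):
        for b in range(-r, r + 1):
            # Shift the (k,l) pair and the (k',l') pair by the same offset (a, b).
            out += Mp[r + a:r + a + P, r + b:r + b + Q,
                      r + a:r + a + P, r + b:r + b + Q]
    return (c_sigma / q ** 2) * out

# Numerical check of Sum{ M . L(N) } = Sum{ L(M) . N }  (the value of c_sigma cancels).
rng = np.random.default_rng(0)
P, Q, q = 5, 6, 3
M, N = rng.normal(size=(P, Q, P, Q)), rng.normal(size=(P, Q, P, Q))
assert np.isclose(np.sum(M * trace_op(N, q)), np.sum(trace_op(M, q) * N))
```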

With this operator, we can expand the CNTK as follows (for simplicity we drop the dependence on $`\vect{x}`$ and $`\vect{x}'`$):

MATH
\begin{align*}
&\Theta^{(L)}\\
= &\tr\left(\dot{\mat{K}}^{(L)}\odot\Theta^{(L-1)}+\mat{K}^{(L)}\right) \\
= & \tr\left(\mat{K}^{(L)}\right) + \tr\left(\dot{\mat{K}}^{(L)}\odot \linop\left(\mat{K}^{(L-1)}\right)\right) + \tr\left(\dot{\mat{K}}^{(L)}\odot\linop\left(
\dot{\mat{K}}^{(L-1)} \odot \mat{\Theta}^{(L-2)} 
\right)\right) \\
= & \ldots \\
= & \sum_{h=0}^{L} \tr\left(
\dot{\mat{K}}^{(L)} \odot \linop\left( \dot{\mat{K}}^{(L-1)} \odot \linop \left(\cdots\dot{\mat{K}}^{(h+1)}\odot\linop\left(\mat{K}^{(h)}\right)\cdots\right)\right)
\right).
\end{align*}

Here, for $`h=L`$, the term is just $`\tr\left(\mat{K}^{(L)}\right)`$.

In the following, we will show

MATH
\begin{align*}
    \left\langle \frac{\partial f(\params,\vect{x})}{\partial \mat{W}^{(h)}} , \frac{\partial f(\params,\vect{x}')}{\partial \mat{W}^{(h)}} \right\rangle 
    \approx 
    &\tr\left(
\dot{\mat{K}}^{(L)} \odot \linop\left( \dot{\mat{K}}^{(L-1)} \odot\linop \left(\cdots\dot{\mat{K}}^{(h)} \odot\linop\left(\mat{K}^{(h-1)}\right)\cdots\right)\right) 
\right)\\
=& \Sum{ \id\odot \dot{\mat{K}}^{(L)} \odot \linop\left( \dot{\mat{K}}^{(L-1)} \odot\linop \left(\cdots\dot{\mat{K}}^{(h)} \odot\linop\left(\mat{K}^{(h-1)}\right)\cdots\right)\right) },
\end{align*}

which can be rewritten, using Property [eq:conjugate2], as

MATH
\begin{align*}
    \left\langle \frac{\partial f(\params,\vect{x})}{\partial \mat{W}^{(h)}} , \frac{\partial f(\params,\vect{x}')}{\partial \mat{W}^{(h)}} \right\rangle \approx 
    \Sum{
\linop\left(\mat{K}^{(h-1)}\right) \odot  \dot{\mat{K}}^{(h)} \odot\linop \left(\dot{\mat{K}}^{(h+1)} \cdots \odot\linop\left(\id \odot \dot{\mat{K}}^{(L)}\right)\cdots\right)
}.
\end{align*}

Derivation

We first compute the derivative of the prediction with respect to one single filter.

MATH
\begin{align*}
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}} = &\left\langle\frac{\partial f(\params,\vect{x})}{\partial \vect{x}_{(\beta)}^{(h)}}, \frac{\partial \vect{x}_{(\beta)}^{(h)}}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}}\right\rangle \\
= & \sum_{(i,j) \in [\nnw] \times [\nnh]} \left\langle\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}},\frac{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}}\right\rangle \\
= & \sum_{(i,j) \in [\nnw] \times [\nnh]} \frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}} \sqrt{\frac{c_{\sigma}}{\nnc^{(h)}q^2}}\reluder{\left[
\tilde{\vect{x}}_{(\beta)}^{(h)}
\right]_{ij}}\phi_{ij}(\vect{x}_{(\alpha)}^{(h-1)}).
\end{align*}

With this expression, we proceed to compute the inner product between the gradients with respect to the $`h`$-th layer filters:

MATH
\begin{align}
&\sum_{\alpha=1}^{\nnc^{(h-1)}}\sum_{\beta=1}^{\nnc^{(h)}}\left\langle 
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}}, \frac{\partial f(\params,\vect{x}')}{\partial \mat{W}_{(\alpha),(\beta)}^{(h)}} \right\rangle \label{eqn:grad_prod}\\
= &\sum_{(i,j,i',j') \in [\nnw] \times [\nnh]\times [\nnw]\times [\nnh] }\frac{c_{\sigma}}{\nnc^{(h)}q^2}\sum_{\beta=1}^{\nnc^{(h)}}
\left(
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right) \nonumber\\
&\cdot\left(\sum_{\alpha=1}^{\nnc^{(h-1)}}\left\langle\phi_{ij}(\vect{x}_{(\alpha)}^{(h-1)}),\phi_{i'j'}(\vect{x}'^{(h-1)}_{(\alpha)})\right\rangle\right).\nonumber
\end{align}

Similar to our derivation of the NTK, we can use the following approximation:

MATH
\begin{align*}
\left(\sum_{\alpha=1}^{\nnc^{(h-1)}}\left\langle\phi_{ij}(\vect{x}_{(\alpha)}^{(h-1)}),\phi_{i'j'}(\vect{x}'^{(h-1)}_{(\alpha)})\right\rangle\right) \approx
\tr\left(\left[\mat{K}^{(h-1)}\right]_{\indset_{ij,i'j'}}\right) = \left[\linop\left(\mat{K}^{(h-1)}\right)\right]_{ij,i'j'}.
\end{align*}

Thus it remains to show that $`\forall (i,j,i',j') \in [\nnw] \times [\nnh]\times [\nnw]\times [\nnh]`$,

MATH
\begin{align*}
&\sum_{\beta=1}^{\nnc^{(h)}}\frac{c_{\sigma}}{\nnc^{(h)}q^2}
\left(
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right)\\
\approx& 
\left[\linop \left(\dot{\mat{K}}^{(h+1)} \cdots \odot\linop\left(\id \odot \dot{\mat{K}}^{(L)}\right)\cdots\right)\odot \dot{\mat{K}}^{(h)} \right]_{i,j,i',j'}
\end{align*}

The key step of this derivation is the following approximation (Equation [eq:gaussian_conditioning2]), which assumes that for each $`(i,j,i',j')`$, $`\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}`$ and $`\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}`$ are independent. This approximation is used and made rigorous for the ReLU activation and fully-connected networks in the proof of Theorem [thm:ntk_init]; a rigorous asymptotic statement for CNNs has also been given in prior work.

MATH
\begin{equation}
\begin{split}\label{eq:gaussian_conditioning2}
&\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} \left(
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right) \\
\approx& 
\left(\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} 
\frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}}
\right)
\left(\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}}
\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}}
\right) 
\end{split}
\end{equation}

Note that

MATH
\begin{align*}
    \frac{c_{\sigma}}{\nnc^{(h)}q^2}\sum_{\beta=1}^{\nnc^{(h)}}\reluder{\left[\tilde{\vect{x}}^{(h)}_{(\beta)}\right]_{ij}}\reluder{\left[\tilde{\vect{x}}'^{(h)}_{(\beta)}\right]_{i'j'}} \approx \left[\dot{\mat{K}}^{(h)}\left(\vect{x},\vect{x}'\right)\right]_{ij,i'j'},
\end{align*}

so the derivation is complete once we show

MATH
\begin{equation}
\label{eq:defi_G}\mat{G}^{(h)}(\vect{x},\vect{x'},\vect{\theta}):=\frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} 
\frac{\partial f(\params,\vect{x})}{\partial \vect{x}_{(\beta)}^{(h)}}\otimes \frac{\partial f(\params,\vect{x}')}{\partial \vect{x}'^{(h)}_{(\beta)}}
\approx 
\linop \left(\dot{\mat{K}}^{(h+1)} \cdots \odot\linop\left(\id \odot \dot{\mat{K}}^{(L)}\right)\cdots\right).
\end{equation}

Now, we tackle the term $`\left( \frac{\partial f(\params,\vect{x})}{\partial [\vect{x}_{(\beta)}^{(h)}]_{ij}}\cdot\frac{\partial f(\params,\vect{x}')}{\partial [\vect{x}'^{(h)}_{(\beta)}]_{i'j'}} \right)`$. Notice that

MATH
\begin{align*}
    \frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}} = & \sum_{\gamma=1}^{\nnc^{(h+1)}}\sum_{(k,\ell) \in [\nnw]\times [\nnh] } 
    \frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}} 
    \frac{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}},
\end{align*}

and for $`\gamma \in [\nnc^{(h+1)}]`$ and $`(k,\ell) \in [\nnw] \times [\nnh]`$

MATH
\begin{align*}
\frac{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}} = \begin{cases}
\sqrt{\frac{c_{\sigma}}{\nnc^{(h+1)}q^2}}\reluder{\left[\tilde{\vect{x}}_{(\gamma)}^{(h+1)}\right]_{k\ell}}  \left[\mat{W}_{(\beta),(\gamma)}^{(h+1)}\right]_{i-k+\frac{q+1}{2},j-\ell+\frac{q+1}{2}}&\text{ if } (i,j) \in \indset_{k\ell} \\
0 &\text{ otherwise }
\end{cases}.
\end{align*}

We then have

MATH
\begin{align}
&\left[\mat{G}^{(h)}(\vect{x},\vect{x'},\vect{\theta})\right]_{ij,i'j'} = \frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} \frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\beta)}^{(h)}\right]_{ij}} \frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x'}^{(h)}_{(\beta)}\right]_{i'j'}} \nonumber \\ 
=&\sum_{k,\ell,k',\ell'} \frac{c_{\sigma}}{\nnc^{(h+1)}q^2}\sum_{\gamma=1}^{\nnc^{(h+1)}}\left(
    \frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}} 
    \frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left(\reluder{\left[\tilde{\vect{x}}^{(h+1)}_{(\gamma)}\right]_{k\ell}}\reluder{\left[\tilde{\vect{x}}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right)\nonumber\\
    & \cdot \frac{1}{\nnc^{(h)}} \sum_{\beta=1}^{\nnc^{(h)}} \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}\right\} \left[\mat{W}_{(\beta),(\gamma)}^{(h+1)}\right]_{i-k+\frac{q+1}{2},j-\ell+\frac{q+1}{2}}\left[\mat{W}_{(\beta),(\gamma)}^{(h+1)}\right]_{i'-k'+\frac{q+1}{2},j'-\ell'+\frac{q+1}{2}}\nonumber\\
\approx & \sum_{k,\ell,k',\ell'} \frac{c_{\sigma}}{\nnc^{(h+1)}q^2}\sum_{\gamma=1}^{\nnc^{(h+1)}}\left(
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}} 
\frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left(\reluder{\left[\tilde{\vect{x}}^{(h+1)}_{(\gamma)}\right]_{k\ell}}\reluder{\left[\tilde{\vect{x}}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right)\nonumber\\
& \cdot \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}, i-k=i'-k',j-\ell=j'-\ell'\right\}\nonumber\\
\approx & \sum_{k,\ell,k',\ell'} \left(\frac{1}{\nnc^{(h+1)}}\sum_{\gamma=1}^{\nnc^{(h+1)}}
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}} 
\frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left(\frac{c_\sigma}{q^2\nnc^{(h+1)}}\sum_{\gamma=1}^{{\nnc^{(h+1)}}}\reluder{\left[\tilde{\vect{x}}^{(h+1)}_{(\gamma)}\right]_{k\ell}}\reluder{\left[\tilde{\vect{x}}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right)\nonumber\\
& \cdot \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}, i-k=i'-k',j-\ell=j'-\ell'\right\} \nonumber\\
\approx & \sum_{k,\ell,k',\ell'} \left(\frac{1}{\nnc^{(h+1)}}\sum_{\gamma=1}^{\nnc^{(h+1)}}
\frac{\partial f(\params,\vect{x})}{\partial \left[\vect{x}_{(\gamma)}^{(h+1)}\right]_{k\ell}} 
\frac{\partial f(\params,\vect{x}')}{\partial \left[\vect{x}'^{(h+1)}_{(\gamma)}\right]_{k'\ell'}}\right) \left[\dot{\mat{K}}^{(h+1)}\left(\vect{x},\vect{x}'\right)\right]_{k\ell,k'\ell'}\nonumber\\
& \cdot \indict\left\{(i,j,i',j')\in \indset_{k\ell,k'\ell'}, i-k=i'-k',j-\ell=j'-\ell'\right\}\nonumber\\
\approx & \tr\left( \left[ \mat{G}^{(h+1)}(\vect{x},\vect{x'},\vect{\theta}) \odot\dot{\mat{K}}^{(h+1)}\left(\vect{x},\vect{x}'\right)\right]_{D_{ij,i'j'}}\right)
 \label{eqn:grad_fxh}
\end{align}

where the first approximation is due to our initialization of $`\mat{W}^{(h+1)}`$. In other words, we’ve shown

MATH
\begin{equation}
\label{eq:recursion_G} \mat{G}^{(h)}(\vect{x},\vect{x'},\vect{\theta}) = \linop\left(\mat{G}^{(h+1)}(\vect{x},\vect{x'},\vect{\theta})\odot \dot{\mat{K}}^{(h+1)}\left(\vect{x},\vect{x}'\right)\right).
\end{equation}

Since we use a fully-connected weight matrix as the last layer, we have $`\mat{G}^{(L)}(\vect{x},\vect{x'},\vect{\theta}) \approx \id`$.

Thus by induction with Equation [eq:recursion_G], we have derived Equation [eq:defi_G], which completes the derivation of CNTK.

For the derivation of CNTK-GAP, the only difference is due to the global average pooling (GAP) layer: $`\mat{G}^{(L)}(\vect{x},\vect{x'},\vect{\theta}) \approx \frac{1}{\nnh^2\nnw^2} \bm{1}\otimes\bm{1}`$, where $`\bm{1}\otimes\bm{1}\in \mathbb{R}^{\nnw\times\nnh\times\nnw\times\nnh}`$ is the all-ones tensor.
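
To make the recursion concrete, here is a minimal NumPy sketch (our own illustration, not the paper's GPU implementation) that evaluates the CNTK from precomputed per-layer covariance tensors $`\mat{K}^{(h)}`$ and $`\dot{\mat{K}}^{(h)}`$ of shape $`\nnw\times\nnh\times\nnw\times\nnh`$, using the backward recursion above; the function names are ours, and we fold the $`c_\sigma/q^2`$ normalization into the `trace_op` helper (re-defined here so the block is self-contained), glossing over the exact placement of constants.

```python
import numpy as np

def trace_op(M, q, c_sigma=2.0):
    """The operator L: sum over matched q x q patch offsets with zero padding, scaled by c_sigma/q^2."""
    P, Q = M.shape[0], M.shape[1]
    r = (q - 1) // 2
    Mp = np.pad(M, r)
    out = np.zeros_like(M)
    for a in range(-r, r + 1):
        for b in range(-r, r + 1):
            out += Mp[r + a:r + a + P, r + b:r + b + Q,
                      r + a:r + a + P, r + b:r + b + Q]
    return (c_sigma / q ** 2) * out

def identity_tensor(P, Q):
    """Id_{i j i' j'} = 1{i = i', j = j'} as a (P, Q, P, Q) array."""
    return np.einsum('ik,jl->ijkl', np.eye(P), np.eye(Q))

def cntk_from_covariances(K, dK, q, gap=False, c_sigma=2.0):
    """Accumulate the CNTK value from precomputed covariance tensors.

    K[h] and dK[h] are (P, Q, P, Q) arrays for K^{(h)} and dot-K^{(h)}, h = 0..L (dK[0] unused).
    Backward recursion: G^{(h)} ~ L(G^{(h+1)} . dK^{(h+1)}), started from G^{(L)} = Id for the
    vanilla CNN or the all-ones tensor / (P*Q)^2 for global average pooling; the per-layer
    contribution is Sum{ L(K^{(h-1)}) . dK^{(h)} . G^{(h)} }.
    """
    L = len(K) - 1
    P, Q = K[0].shape[0], K[0].shape[1]
    G = identity_tensor(P, Q) if not gap else np.ones((P, Q, P, Q)) / (P * Q) ** 2
    # Last-layer contribution: tr(K^{(L)}) for the vanilla CNN; absent for GAP (last layer untrained).
    theta = float(np.sum(identity_tensor(P, Q) * K[L])) if not gap else 0.0
    for h in range(L, 0, -1):
        if not (gap and h == 1):        # the GAP variant does not train the first layer
            theta += float(np.sum(trace_op(K[h - 1], q, c_sigma) * dK[h] * G))
        G = trace_op(G * dK[h], q, c_sigma)
    return theta
```

In the infinite-width limit the covariance tensors themselves come from the closed-form recursions for $`\mat{\Sigma}^{(h)}`$ and $`\dot{\mat{K}}^{(h)}`$; here they are simply treated as given inputs.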


In this section we define CNNs studied in this paper and derive the corresponding CNTKs.

Fully-connected NN

We first define a fully-connected neural network and its corresponding gradient.

MATH
\begin{align}
f(\params,\vect{x}) = \vect{a}^\top \relu{\mat{W}_H\relu{\mat{W}_{H-1}\ldots\relu{\mat{W}_1\vect{x}}}} \label{eqn:fc}
\end{align}

We can write the gradient in a compact form

MATH
\begin{align}
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}_{h}} = \back_{h+1}(\vect{x})^\top  \left(\vect{x}^{(h-1)}\right)^\top
\end{align}

where

MATH
\begin{align}
\back_{h}(\vect{x}) = \vect{a}^\top \mat{D}_{H}(\vect{x})\mat{W}_{H}\cdots \mat{D}_h(\vect{x})\mat{W}_{h} \in \mathbb{R}^{1 \times m} \\
\mat{D}_h(\vect{x}) = \diag\left(\dot{\sigma}\left(\mat{W}_h\vect{x}^{(h-1)}\right)\right) \in \mathbb{R}^{m \times m}
\end{align}

and

MATH
\begin{align}
\vect{x}^{(h)} = \relu{\mat{W}_{h}\relu{\mat{W}_{h-1}\cdots\relu{\mat{W}_{1}\vect{x}}}}
\end{align}
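
As a finite-width sanity check on these formulas, here is a minimal NumPy sketch (our own; all names are ours, and the usual width-dependent scaling is folded into the initialization rather than written in the forward pass as in Equation [eqn:fc]) that computes the gradients of the scalar output by backpropagation and forms the empirical kernel $`\left\langle \frac{\partial f(\params,\vect{x})}{\partial \params}, \frac{\partial f(\params,\vect{x}')}{\partial \params}\right\rangle`$ over the weight matrices. As the width $`m`$ grows, this quantity concentrates around its infinite-width limit, which in the paper's parameterization is the NTK.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def forward(x, Ws, a):
    """Activations x^{(0)}, ..., x^{(H)} of f(theta, x) = a^T relu(W_H ... relu(W_1 x))."""
    acts = [x]
    for W in Ws:
        acts.append(relu(W @ acts[-1]))
    return acts

def gradients(x, Ws, a):
    """df/dW_h for every layer, via backpropagation through the ReLU net."""
    acts = forward(x, Ws, a)
    grads = [None] * len(Ws)
    b = a.copy()                                   # b = df / d x^{(h)}, starting at h = H
    for h in range(len(Ws) - 1, -1, -1):           # python index h corresponds to layer h+1
        D = (Ws[h] @ acts[h] > 0).astype(x.dtype)  # ReLU derivative at layer h+1's pre-activation
        bD = b * D                                 # df / d(pre-activation of layer h+1)
        grads[h] = np.outer(bD, acts[h])           # df / dW_{h+1} = (b . D) (x^{(h)})^T
        b = Ws[h].T @ bD                           # df / d x^{(h)} for the next iteration
    return grads

def empirical_ntk(x1, x2, Ws, a):
    """<df(theta,x1)/dtheta, df(theta,x2)/dtheta>, summed over all weight matrices."""
    g1, g2 = gradients(x1, Ws, a), gradients(x2, Ws, a)
    return sum(float(np.sum(u * v)) for u, v in zip(g1, g2))

# Example: a depth-3 net of width m; the kernel concentrates as m grows.
rng = np.random.default_rng(0)
d, m, H = 10, 4096, 3
Ws = [rng.normal(size=(m, d)) / np.sqrt(d)] + \
     [rng.normal(size=(m, m)) * np.sqrt(2.0 / m) for _ in range(H - 1)]
a = rng.normal(size=m) / np.sqrt(m)
x1, x2 = rng.normal(size=d), rng.normal(size=d)
print(empirical_ntk(x1, x2, Ws, a))
```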

CNN

We consider the following two-dimensional convolutional neural network.

  • Input $`\vect{x}^{(0)} \in \mathbb{R}^{\nnw\times \nnh \times \nnc^{(0)}}`$, where $`\nnw`$ is the width, $`\nnh`$ is the height and $`\nnc^{(0)}`$ is the number of input channels.

  • For $`\ell=1,\ldots,L`$, $`\beta = 1,\ldots,\nnc^{(\ell)}`$, the intermediate inputs are defined as

    MATH
    \begin{align}
    \tilde{\vect{x}}_{(\beta)}^{(\ell)} = \sum_{\alpha=1}^{\nnc^{(\ell-1)}} \mat{W}_{(\alpha),(\beta)}^{(\ell)} \conv \vect{x}_{(\alpha)}^{(\ell-1)} \in \mathbb{R}^{\nnw \times \nnh}
    \end{align}

    where $`\mat{W}_{(\alpha),(\beta)}^{(\ell)} \sim \gauss(\vect{0},\mat{I}) \in \mathbb{R}^{q^{(\ell)} \times q^{(\ell)}}`$ and

    MATH
    \begin{align}
    \vect{x}^{(\ell)}_{(\beta)} = \sqrt{\frac{c_{\sigma}}{\nnc^{(\ell)} \times q^{(\ell)} \times q^{(\ell)}}}\act{\tilde{\vect{x}}_{(\beta)}^{(\ell)}}.
    \end{align}
  • The final output is defined as

    MATH
    \begin{align}
    f(\vect{x}^{(0)}) = \sum_{\beta=1}^{\nnc^{(L)}} \left\langle \mat{W}_{(\beta)}^{(L)},\vect{x}_{(\beta)}^{(L)}\right\rangle.
    \end{align}

The convolution operator is defined as follows: for $`\mat{W} \in \mathbb{R}^{q\times q}`$ and $`\vect{x} \in \mathbb{R}^{\nnw \times \nnh}`$,

MATH
\begin{align}
[\mat{W}\conv \vect{x}]_{ij} = \sum_{a = -\frac{q-1}{2}}^{\frac{q-1}{2}} \sum_{b = -\frac{q-1}{2}}^{\frac{q-1}{2}}\vect{x}_{i+a,j+b} \cdot \mat{W}_{a+\frac{q+1}{2},b+\frac{q+1}{2}}.
\end{align}
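
As a direct transcription of this definition, here is a small NumPy helper (our own, assuming an odd filter size $`q`$ and zero padding at the border, consistent with the patch operator used earlier).

```python
import numpy as np

def conv_op(W, x):
    """[W * x]_{ij} = sum_{a,b} x_{i+a, j+b} * W_{a+(q+1)/2, b+(q+1)/2}, with zero padding."""
    q = W.shape[0]                       # odd filter size
    r = (q - 1) // 2
    P, Q = x.shape
    xp = np.pad(x, r)                    # zero padding so the output keeps size P x Q
    out = np.zeros((P, Q))
    for a in range(-r, r + 1):
        for b in range(-r, r + 1):
            # W[a + r, b + r] is the 0-based entry corresponding to W_{a+(q+1)/2, b+(q+1)/2}.
            out += W[a + r, b + r] * xp[r + a:r + a + P, r + b:r + b + Q]
    return out
```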

In this subsection we prove the following lemma.

Fix $`\omega \le\poly(1/L,1/n,1/\log(1/\delta), \lambda_0)`$. Suppose we set $`m \ge \poly(1/\omega)`$ and $`\kappa \le 1`$. Then with probability at least $`1-\delta`$ over random initialization, we have for all $`t \ge 0`$, for any $`(\vect{x},\vect{x}') \in \left\{\vect{x}_1,\ldots,\vect{x}_n,\vect{x}_{te}\right\} \times \left\{\vect{x}_1,\ldots,\vect{x}_n,\vect{x}_{te}\right\}`$

MATH
\begin{equation*}
    \abs{
\kernel_{t}\left(\vect{x},\vect{x}'\right)- \kernel_{0}\left(\vect{x},\vect{x}'\right)
    } \le \omega
\end{equation*}

Recall for any fixed $`\vect{x}`$ and $`\vect{x}'`$, Theorem [thm:ntk_init] shows $`\abs{\kernel_{0}(\vect{x},\vect{x}')-\kernel_{ntk}(\vect{x},\vect{x}')} \le \epsilon`$ if $`m`$ is large enough. The next lemma shows we can reduce the problem of bounding the perturbation on the kernel value to the perturbation on the gradient.

If $`\norm{\frac{\partial f(\params(t),\vect{x})}{\partial \params} - \frac{\partial f(\params(0),\vect{x})}{\partial \params}} \le \epsilon`$ and $`\norm{\frac{\partial f(\params(t),\vect{x}')}{\partial \params} - \frac{\partial f(\params(0),\vect{x}')}{\partial \params}} \le \epsilon`$, we have

MATH
\begin{align*}
        \abs{\kernel_{t}(\vect{x},\vect{x}') - \kernel_{0}(\vect{x},\vect{x}')} \le O\left(\epsilon\right)
\end{align*}

Proof. By the proof of Theorem [thm:ntk_init], we know $`\norm{\frac{\partial f(\params(0),\vect{x})}{\partial \params}}_2 = O\left(1\right)`$. Then we can just use triangle inequality. ◻
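
For completeness, the triangle-inequality step can be spelled out as follows (our own expansion of the proof above, writing $`\vect{g}_t(\vect{x}) = \frac{\partial f(\params(t),\vect{x})}{\partial \params}`$ so that $`\kernel_{t}(\vect{x},\vect{x}') = \left\langle \vect{g}_t(\vect{x}), \vect{g}_t(\vect{x}')\right\rangle`$):

MATH
\begin{align*}
\abs{\kernel_{t}(\vect{x},\vect{x}') - \kernel_{0}(\vect{x},\vect{x}')}
&\le \abs{\left\langle \vect{g}_t(\vect{x})-\vect{g}_0(\vect{x}), \vect{g}_t(\vect{x}')\right\rangle} + \abs{\left\langle \vect{g}_0(\vect{x}), \vect{g}_t(\vect{x}')-\vect{g}_0(\vect{x}')\right\rangle} \\
&\le \epsilon\left(\norm{\vect{g}_0(\vect{x}')}_2+\epsilon\right) + \epsilon\,\norm{\vect{g}_0(\vect{x})}_2 = O\left(\epsilon\right).
\end{align*}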

Now we proceed to analyze the perturbation on the gradient. Note we can focus on the perturbation on a single sample $`\vect{x}`$ because we can later take a union bound. Therefore, in the rest of this section, we drop the dependency on a specific sample. We use the following notations in this section. Recall $`\mat{W}^{(1)},\ldots,\mat{W}^{(L+1)} \sim \gauss\left(\mat{0},\mat{I}\right)`$ and we denote $`\diff \mat{W}^{(1)},\ldots,\diff \mat{W}^{(L+1)}`$ the perturbation matrices. We let $`\widetilde{\mat{W}}^{(h)} = \mat{W}^{(h)} + \diff\mat{W}^{(h)}`$. We let $`\tilde{\vect{g}}^{(0)} = \vect{g}^{(0)}=\vect{x}`$ and for $`h=1,\ldots,L`$ we define

MATH
\begin{align*}
\vect{z}^{(h)} = &\sqrt{\frac{2}{m}}\mat{W}^{(h)} 
\vect{g}^{(h-1)}, ~~~ \vect{g}^{(h)} = \relu{\vect{z}^{(h)}},\\
\tilde{\vect{z}}^{(h)} = &\sqrt{\frac{2}{m}}\widetilde{\mat{W}}^{(h)}
\tilde{\vect{g}}^{(h-1)}, ~~~ \tilde{\vect{g}}^{(h)} = \relu{\tilde{\vect{z}}^{(h)}}.
\end{align*}

For $`h=1,\ldots,L`$, $`i=1,\ldots,m`$, we denote

MATH
\begin{align*}
[\mat{D}^{(h)}]_{ii} =
 &\indict\left\{\left[\mat{W}^{(h)}\right]_{i,:}\vect{g}^{(h-1)}\ge 0\right\}\\
[\widetilde{\mat{D}}^{(h)}]_{ii} =
 &\indict\left\{\left[\widetilde{\mat{W}}^{(h)}\right]_{i,:}\tilde{\vect{g}}^{(h-1)}\ge 0\right\}.
\end{align*}

Note $`\vect{z}^{(h)} = \sqrt{\frac{2}{m}}\vect{f}^{(h)}`$. Here we use $`\vect{z}^{(h)}`$ instead of $`\vect{f}^{(h)}`$ for ease of presentation.

For convenience, we also define

MATH
\begin{align*}
\diff \mat{D}^{(h)} = \widetilde{\mat{D}}^{(h)} - \mat{D}^{(h)}.
\end{align*}

Recall the gradient to $`\mat{W}^{(h)}`$ is:

MATH
\begin{align*}
\frac{\partial f(\params,\vect{x})}{\partial \mat{W}^{(h)}} = \vect{b}^{(h)} \left(\vect{g}^{(h-1)}\right)^\top
\end{align*}

Similarly, we have

MATH
\begin{align*}
\frac{\partial f(\params,\vect{x})}{\partial \widetilde{\mat{W}}^{(h)}} = \tilde{\vect{b}}^{(h)} \left(\tilde{\vect{g}}^{(h-1)}\right)^\top
\end{align*}

where

MATH
\begin{align*}
\tilde{\vect{b}}^{(h)} = \begin{cases}
1 &\text{ if } h = L+1\\
\sqrt{\frac{2}{m}}\widetilde{\mat{D}}^{(h)}\left(\widetilde{\mat{W}}^{(h+1)}\right)^\top\tilde{\vect{b}}^{(h+1)} &\text{ Otherwise}
\end{cases} .
\end{align*}

This gradient formula allows us to bound the perturbations $`\diff \vect{g}^{(h)}\triangleq\tilde{\vect{g}}^{(h)} - \vect{g}^{(h)}`$ and $`\diff \vect{b}^{(h)}\triangleq\tilde{\vect{b}}^{(h)}-\vect{b}^{(h)}`$ separately. The following lemmas, adapted from prior work, show that with high probability over the initialization, bounding the perturbations of $`\diff \vect{g}^{(h)}`$ and $`\diff \vect{b}^{(h)}`$ can be reduced to bounding the perturbations of the weight matrices.

Suppose

MATH
\omega\le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta)\right).

Then with probability at least $`1-\delta`$ over random initialization, if $`\norm{\diff \mat{W}^{(h)}}_2 \le \sqrt{m}\omega`$ for all $`h=1,\ldots,L`$, we have $`\norm{\diff \vect{g}^{(h)}}_2 = O(\omega L^{5/2}\sqrt{\log m})`$ for all $`h=1,\ldots,L`$.

While the original analysis did not consider a perturbation on $`\mat{W}^{(1)}`$, by scrutinizing its proof, it is easy to see that the perturbation bounds still hold even if there is a small perturbation on $`\mat{W}^{(1)}`$.

The next lemma, adapted from the same line of work, bounds the backward vector.

Suppose

MATH
\omega\le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta)\right).

Then with probability at least $`1-\delta`$ over random initialization, if $`\norm{\diff \mat{W}^{(h)}}_2 \le \sqrt{m}\omega`$ for all $`h=1,\ldots,L+1`$, we have for all $`h=1,\ldots,L+1`$, $`\norm{\tilde{\vect{b}}^{(h)}-\vect{b}^{(h)}}_2 = O\left(\omega^{1/3}L^2\sqrt{\log m}\right)`$.

While the original analysis did not consider a perturbation on $`\mat{W}^{(L+1)}`$, by scrutinizing its proof, it is easy to see that the perturbation bounds still hold even if there is a small perturbation on $`\mat{W}^{(L+1)}`$.

Combining these two lemmas and the result at initialization (Theorem [thm:ntk_init]), we have the following “gradient-Lipschitz” lemma.

Suppose $`\omega \le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta)\right).`$ Then with probability at least $`1-\delta`$ over random initialization, if $`\norm{\diff \mat{W}^{(h)}}_2 \le \sqrt{m}\omega`$ for all $`h=1,\ldots,L+1`$, we have for all $`h=1,\ldots,L+1`$:

MATH
\begin{align*}
    \norm{\tilde{\vect{b}}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\vect{g}^{(h-1)}\right)^\top}_F = O\left(\omega^{1/3}L^{5/2}\sqrt{\log m}\right)
\end{align*}

Proof. We use the triangle inequality to bound the perturbation

MATH
\begin{align*}
    &\norm{\tilde{\vect{b}}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\vect{g}^{(h-1)}\right)^\top}_F \\
\le&  \norm{\tilde{\vect{b}}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top}_F  + \norm{\vect{b}^{(h)}\left(\tilde{\vect{g}}^{(h-1)}\right)^\top-\vect{b}^{(h)}\left(\vect{g}^{(h-1)}\right)^\top}_F \\
\le & \norm{
\diff \vect{b}^{(h)}\left(\vect{g}^{(h-1)}+\diff \vect{g}^{(h-1)}\right)^\top
}_F + \norm{\vect{b}^{(h)}\left(\diff \vect{g}^{(h-1)}\right)^\top}_F\\
= & O\left(\omega^{1/3}L^{5/2}\sqrt{\log m}\right).
\end{align*}

 ◻

The following lemma shows that, for a given weight matrix, if we have linear convergence and the other weight matrices are only perturbed a little, then the given matrix is also only perturbed a little.

Fix $`h \in [L+1]`$ and a sufficiently small $`\omega \le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta),\kappa\right).`$ Suppose for all $`t \ge 0`$, $`\norm{\vect{u}_{nn}(t)-\vect{y}}_2 \le \exp\left(-\frac12\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2`$ and $`\norm{\mat{W}^{(h')}(t)-\mat{W}^{(h')}(0)}_F \le \omega \sqrt{m}`$ for $`h'\neq h`$. Then if $`m \ge \poly\left(1/\omega\right)`$ we have with probability at least $`1-\delta`$ over random initialization, for all $`t \ge 0`$

MATH
\begin{align*}
\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F = O \left( \frac{\sqrt{n}}{\lambda_0}\right)\le\omega\sqrt{m}.
\end{align*}

Proof. We let $`C,C_0, C_1, C_2, C_3 > 0`$ be some absolute constants.

MATH
\begin{align*}
&\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F\\
 = &\norm{\int_{0}^{t}\frac{d \mat{W}^{(h)}(\tau)}{d\tau} d\tau}_F \\
=& \norm{\int_{0}^{t}\frac{\partial L(\params(\tau))}{\partial \mat{W}^{(h)}(\tau)} d\tau}_F \\
= & \norm{\int_{0}^{t}\frac{1}{n}\sum_{i=1}^n \left(u_i(\tau)-y_i\right) \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}} d\tau}_F \\
\le & \frac{1}{n}\max_{0\le \tau\le t}\sum_{i=1}^n\norm{ \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\int_{0}^{t} \norm{\vect{u}_{nn}(\tau)-\vect{y}}_2 d\tau\\
\le & \frac{1}{n}\max_{0\le \tau\le t}\sum_{i=1}^n\norm{ \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\int_{0}^{t} \exp\left(-\kappa^2\lambda_0\tau\right) d\tau\norm{\vect{u}_{nn}(0)-\vect{y}}_2 \\
\le & \frac{C_0}{\sqrt{n}\lambda_0}\max_{0\le \tau\le t}\sum_{i=1}^n\norm{ \frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F &\\
\le &\frac{C_0}{\sqrt{n}\kappa^2\lambda_0}\max_{0\le \tau\le t}\sum_{i=1}^n\left(\norm{ \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F + \norm{\frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\right)\\
\le &\frac{C_1}{\sqrt{n}\lambda_0}\max_{0\le \tau\le t}\sum_{i=1}^n\left(\norm{ \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F + \norm{\frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\right) \\
\le &\frac{C_2\sqrt{n}}{\lambda_0}+\frac{C_1\sqrt{n}}{\lambda_0}\max_{0\le \tau\le t}\left( \norm{\frac{\partial f_{nn}(\params(\tau),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F\right).
\end{align*}

In the last step we used $`\norm{ \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F = O(1)`$. Suppose there exists $`t`$ such that $`\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}.`$ Denote

MATH
t_0 = \argmin_{t\ge 0}\left\{\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}\right\}.

For any $`t < t_0`$, we know for all $`h' \in [L+1]`$, $`\norm{\mat{W}^{(h')}(t)-\mat{W}^{(h')}(0)}_2 \le \omega\sqrt{m}.`$ Therefore, by Lemma [lem:grad_lip], we know

MATH
\begin{align*}
\norm{\frac{\partial f_{nn}(\params(t),\vect{x}_i)}{\partial \mat{W}^{(h)}}- \frac{\partial f_{nn}(\params(0),\vect{x}_i)}{\partial \mat{W}^{(h)}}}_F \le C \omega^{1/3}L^{5/2}\sqrt{\log m}.
\end{align*}

Therefore, using the fact that $`\omega`$ is sufficiently small we can bound

MATH
\begin{align*}
\norm{\mat{W}^{(h)}(t_0)-\mat{W}^{(h)}(0)}_F \le \frac{C_3\sqrt{n}}{\lambda_0}.
\end{align*}

Since we also know $`m`$ is sufficiently large to make $`\omega\sqrt{m} > \frac{C_3\sqrt{n}}{\lambda_0}`$, we have a contradiction. ◻

The next lemma shows if all weight matrices only have small perturbation, then we still have linear convergence.

Suppose $`\omega \le \poly\left(1/n,\lambda_0,1/L,1/\log(m),\epsilon,1/\log(1/\delta),\kappa\right).`$ Suppose for all $`t \ge 0`$, $`\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F \le \omega \sqrt{m}`$ for $`h \in [L+1]`$. Then if $`m \ge \poly\left(1/\omega\right)`$, we have with probability at least $`1-\delta`$ over random initialization, for all $`t \ge 0`$

MATH
\begin{align*}
\norm{\vect{u}_{nn}(t)-\vect{y}}_2 \le \exp\left(-\frac{1}{2}\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2.
\end{align*}

Proof. Under this assumption and the result at initialization, we know that for all $`t \ge 0`$, $`\lambda_{\min}\left(\trainker(t)\right) \ge \frac{1}{2}\lambda_0`$. This in turn directly implies the linear convergence result we want. ◻
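
To spell out that implication (our own one-line derivation, assuming the standard setup in which the network output is scaled by $`\kappa`$ so that $`\frac{d \vect{u}_{nn}(t)}{dt} = -\kappa^2\,\trainker(t)\left(\vect{u}_{nn}(t)-\vect{y}\right)`$):

MATH
\begin{align*}
\frac{d}{dt}\norm{\vect{u}_{nn}(t)-\vect{y}}_2^2 = -2\kappa^2\left(\vect{u}_{nn}(t)-\vect{y}\right)^\top \trainker(t)\left(\vect{u}_{nn}(t)-\vect{y}\right) \le -\kappa^2\lambda_0\norm{\vect{u}_{nn}(t)-\vect{y}}_2^2,
\end{align*}

and integrating this differential inequality gives exactly the claimed bound.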

Lastly, with these lemmas at hand, using an argument similar to prior work, we can show that during training the weight matrices do not move much.

Let $`\omega \le \poly(\eps,L,\lambda_0,1/\log(m),1/\log(1/\delta),\kappa, 1/n)`$. If $`m \ge \poly(1/\omega)`$, then with probability at least $`1-\delta`$ over random initialization, we have for all $`t \ge 0`$, for all $`h \in [L+1]`$ we have

MATH
\begin{align*}
\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F \le \omega\sqrt{m}
\end{align*}

and

MATH
\begin{align*}
\norm{\vect{u}_{nn}(t)-\vect{y}}_2 \le \exp\left(-\frac{1}{2}\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2.
\end{align*}

Proof. Let

MATH
\begin{align*}
t_0 = \argmin_{t}\left\{\exists h \in [L+1], \norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m} \right.\\
\left.\text{ or }\norm{\vect{u}_{nn}(t)-\vect{y}}_2 > \exp\left(-\frac{1}{2}\kappa^2\lambda_0t\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2\right\}.
\end{align*}

We analyze case by case. Suppose at time $`t_0`$, $`\norm{\mat{W}^{(h)}(t_0)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}`$. By Lemma [lem:lin_conv_w_perb_small], we know there exists some $`0\le t_1 < t_0`$ such that either there exists $`h' \neq h`$ such that

MATH
\norm{\mat{W}^{(h')}(t_1)-\mat{W}^{(h')}(0)}_F > \omega\sqrt{m}

or

MATH
\norm{\vect{u}_{nn}(t_1)-\vect{y}}_2 > \exp\left(-\frac{1}{2}\kappa^2\lambda_0t_1\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2.

However, this violates the minimality of $`t_0`$. For the other case, if

MATH
\norm{\vect{u}_{nn}(t_0)-\vect{y}}_2 > \exp\left(-\frac{1}{2}\kappa^2\lambda_0t_0\right)\norm{\vect{u}_{nn}(0)-\vect{y}}_2,

by Lemma [lem:w_perb_small_lin_conv], we know there exists $`t_1 < t_0`$ such that there exists $`h \in [L+1]`$ with

MATH
\norm{\mat{W}^{(h)}(t_1)-\mat{W}^{(h)}(0)}_F > \omega\sqrt{m}.

However, again this violates the minimality of $`t_0`$. ◻

Now we can finish the proof of Lemma [lem:ker_perb_train].

Proof of Lemma [lem:ker_perb_train]. By Lemma [lem:cont_induction], we know that for all $`t \ge 0`$, $`\norm{\mat{W}^{(h)}(t)-\mat{W}^{(h)}(0)}_F \le O\left(\omega\sqrt{m}\right)`$ if $`\omega`$ is sufficiently small. Applying Lemma [lem:grad_lip], we know we only have a small perturbation on the gradient. Applying Lemma [lem:grad_to_kernel], we know we only have a small perturbation on the kernel values. ◻

From a Gaussian process (GP) viewpoint, the correspondence between infinite neural networks and kernel machines was first noted in early work on GP limits of neural networks. Follow-up work extended this correspondence to more general shallow neural networks, and more recently to deep and convolutional neural networks as well as a variety of other architectures. However, these kernels, as we discussed in Section 1, represent weakly-trained nets instead of fully-trained nets.

Beyond GPs, the connection between neural networks and kernels is also studied in the compositional kernel literature. One line of work derived a closed-form kernel formula for rectified polynomial activations, which include ReLU as a special case. Another line of work proposed a general framework to transform a neural network into a compositional kernel and later showed that for sufficiently wide neural networks, stochastic gradient descent can learn functions that lie in the corresponding reproducing kernel Hilbert space. However, the kernels studied in these works still correspond to weakly-trained neural networks.

This paper is inspired by a line of recent work on over-parameterized neural networks. These papers established that for (convolutional) neural networks with large but finite width, (stochastic) gradient descent can achieve zero training error. A key component in these papers is showing that the weight matrix at each layer stays close to its initialization. This observation implies that the kernel defined in Equation [eqn:ntk] also stays close to its initialization, and it has been used explicitly to derive generalization bounds for two-layer over-parameterized neural networks. Some have argued that results in this kernel regime may be too simple to explain the success of deep learning; on the other hand, our results show that the CNTK is at least able to perform well on tasks like CIFAR-10 classification. See also recent surveys for advances in deep learning theory.

Jacot et al. [2018] derived the exact same kernel from kernel gradient descent. They showed that if the number of neurons per layer goes to infinity in a sequential order, then the kernel remains unchanged for a finite training time, and they termed the derived kernel the Neural Tangent Kernel (NTK). We follow the same naming convention and name its convolutional extension the Convolutional Neural Tangent Kernel (CNTK). Later work derived a formula for the CNTK as well as a mechanistic way to derive the NTK for different architectures; compared with that formula, ours has a more explicit convolutional structure and results in an efficient, GPU-friendly computation method. Recent work has also tried to empirically verify this theory by studying the linearization of neural nets, observing that in the first few iterations the linearization stays close to the actual neural net. However, as will be shown in Section 8, such linearization can decrease the classification accuracy by $`5\%`$ even on a “CIFAR-2” (airplane vs. car) dataset. Therefore, exact kernel evaluation is important for studying the power of the NTK and CNTK.

In this section we define CNN with global average pooling considered in this paper and its corresponding CNTK formula.

CNN definition.

Let $`\vect{x} = \vect{x}^{(0)} \in \mathbb{R}^{\nnw\times \nnh \times \nnc^{(0)}}`$ be the input image, where $`\nnc^{(0)}`$ is the number of channels in the input image.

For $`h=1,\ldots,L`$, $`\beta = 1,\ldots,\nnc^{(h)}`$, the intermediate outputs are defined as

MATH
\begin{align*}
    \tilde{\vect{x}}_{(\beta)}^{(h)} = \sum_{\alpha=1}^{\nnc^{(h-1)}} \mat{W}_{(\alpha),(\beta)}^{(h)} \conv \vect{x}_{(\alpha)}^{(h-1)} ,\quad
    \vect{x}^{(h)}_{(\beta)} = \sqrt{\frac{c_{\sigma}}{\nnc^{(h)} \times q^{(h)} \times q^{(h)}}}\act{\tilde{\vect{x}}_{(\beta)}^{(h)}}.
\end{align*}

The final output is defined as

MATH
\begin{align*}
    f(\params,\vect{x}) =\sum_{\alpha=1}^{\nnc^{(L)}}  W_{(\alpha)}^{(L + 1)}\left(  \frac{1}{\nnw \nnh} \sum_{(i,j) \in [\nnw] \times [\nnh]}\left[\vect{x}_{(\alpha)}^{(L)}\right]_{ij}\right).
\end{align*}

where $`W_{(\alpha)}^{(L + 1)} \in \mathbb{R}`$ is a scalar with Gaussian initialization.

Besides using global average pooling, another modification is that we do not train the first and the last layers. This is inspired by prior work in which the authors showed that if one applies gradient flow, then at any training time $`t`$, the difference between the squared Frobenius norm of a weight matrix at time $`t`$ and that at initialization is the same for all layers. However, note that $`\mat{W}^{(1)}`$ and $`\mat{W}^{(L+1)}`$ are special because they are smaller matrices compared with the intermediate weight matrices, so relatively, these two weight matrices change more than the intermediate ones during the training process, and this may dramatically change the kernel. Therefore, we choose to fix $`\mat{W}^{(1)}`$ and $`\mat{W}^{(L+1)}`$ to make the over-parameterization theory closer to practice.

CNTK formula.

We let $`\vect{x},\vect{x}'`$ be two input images. Note that because the CNN with global average pooling and the vanilla CNN share the same architecture except for the last layer, $`\mat{\Sigma}^{(h)}(\vect{x},\vect{x}')`$, $`\dot{\mat{\Sigma}}^{(h)}(\vect{x},\vect{x}')`$ and $`\mat{K}^{(h)}(\vect{x},\vect{x}')`$ are the same for the two architectures; the only difference is in calculating the final kernel value. To compute the final kernel value, we use the following procedure.

First, we define $`\mat{\Theta}^{(0)}(\vect{x},\vect{x}') = \vect{0}`$. Note this is different from the CNTK for the vanilla CNN, which uses $`\mat{\Sigma}^{(0)}`$ as the initial value; the difference arises because we do not train the first layer.

For $`h=1,\ldots,L - 1`$ and $`(i,j,i',j') \in [\nnw] \times [\nnh] \times [\nnw] \times [\nnh]`$, we define

MATH
\begin{align*}
\left[\mat{\Theta}^{(h)}(\vect{x},\vect{x}')\right]_{ij,i'j'} = \tr\left(\left[\dot{\mat{K}}^{(h)}(\vect{x},\vect{x}')\odot\mat{\Theta}^{(h-1)}(\vect{x},\vect{x}')+\mat{K}^{(h)}(\vect{x},\vect{x}')\right]_{D_{ij,i'j'}}\right).
\end{align*}

For $`h=L`$, we define $`\mat{\Theta}^{(L)}(\vect{x},\vect{x}') = \dot{\mat{K}}^{(L)}(\vect{x},\vect{x}')\odot\mat{\Theta}^{(L-1)}(\vect{x},\vect{x}')`$.

Lastly, the final kernel value is defined as

MATH
\frac{1}{\nnw^2 \nnh^2} \sum_{(i,j,i',j') \in [\nnw] \times [\nnh] \times [\nnw] \times [\nnh]} \left[\mat{\Theta}^{(L)}(\vect{x},\vect{x}')\right]_{ij,i'j'} .

Note that we ignore $`\mat{K}^{(L)}`$ compared with the CNTK of the vanilla CNN; this is because we do not train the last layer. The other difference is that we calculate the mean over all entries instead of the summation over only the diagonal ones, because with global average pooling the cross-covariances between every pair of patches contribute to the kernel.
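
Putting the three steps together, here is a minimal NumPy sketch of this procedure (our own illustration, not the paper's GPU implementation; it assumes the per-layer tensors $`\mat{K}^{(h)}`$ and $`\dot{\mat{K}}^{(h)}`$ have already been computed, and the helper names are ours).

```python
import numpy as np

def patch_trace(M, q):
    """tr([M]_{D_{ij,i'j'}}) for every (i,j,i',j'): sum over matched q x q offsets, zero padded."""
    P, Q = M.shape[0], M.shape[1]
    r = (q - 1) // 2
    Mp = np.pad(M, r)
    out = np.zeros_like(M)
    for a in range(-r, r + 1):
        for b in range(-r, r + 1):
            out += Mp[r + a:r + a + P, r + b:r + b + Q,
                      r + a:r + a + P, r + b:r + b + Q]
    return out

def cntk_gap(K, dK, q):
    """Final CNTK-GAP value between two inputs.

    K[h], dK[h]: (P, Q, P, Q) arrays for K^{(h)} and dot-K^{(h)}, h = 1..L (index 0 unused).
    Theta^{(0)} = 0 (first layer untrained); for h < L,
    Theta^{(h)} = patch_trace(dK^{(h)} * Theta^{(h-1)} + K^{(h)});
    Theta^{(L)} = dK^{(L)} * Theta^{(L-1)}; the kernel value is the mean over all entries.
    """
    L = len(dK) - 1
    P, Q = dK[L].shape[0], dK[L].shape[1]
    theta = np.zeros((P, Q, P, Q))      # Theta^{(0)} = 0
    for h in range(1, L):
        theta = patch_trace(dK[h] * theta + K[h], q)
    theta = dK[L] * theta               # no patch trace and no K^{(L)} at the last layer
    return float(theta.mean())          # average over all (i,j,i',j'), i.e. divide by (P*Q)^2
```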


📊 Paper Figures

Figure 1



A Note of Gratitude

The copyright of this content belongs to the respective researchers. We deeply appreciate their hard work and contribution to the advancement of human civilization.
