Adapting Newton's Method to Neural Networks through a Summary of Higher-Order Derivatives

6 Dec 2023 · Pierre Wolinski

We consider a gradient-based optimization method applied to a function $\mathcal{L}$ of a vector of variables $\boldsymbol{\theta}$, in the case where $\boldsymbol{\theta}$ is represented as a tuple of tensors $(\mathbf{T}_1, \dots, \mathbf{T}_S)$. This framework covers many common use cases, such as training neural networks by gradient descent. First, we propose a computationally inexpensive technique, based on automatic differentiation and computational tricks, that provides higher-order information on $\mathcal{L}$, in particular about the interactions between the tensors $\mathbf{T}_s$. Second, we use this technique at order 2 to build a second-order optimization method suitable, among other things, for training deep neural networks of various architectures. This second-order method exploits the partition of $\boldsymbol{\theta}$ into tensors $(\mathbf{T}_1, \dots, \mathbf{T}_S)$, so that it requires neither the computation of the Hessian of $\mathcal{L}$ with respect to $\boldsymbol{\theta}$ nor any approximation of it. The key step is to compute a smaller matrix, interpretable as a "Hessian according to the partition", which can be computed exactly and efficiently. In contrast to many existing practical second-order methods used for neural networks, which rely on a diagonal or block-diagonal approximation of the Hessian or its inverse, the proposed method does not neglect interactions between layers. Finally, the coarseness of the partition can be tuned to recover well-known optimization methods: the coarsest partition corresponds to Cauchy's steepest descent method, and the finest corresponds to the usual Newton's method.
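
The role of the partition can be illustrated on a toy quadratic loss. The sketch below is one reading of the abstract, not the paper's exact algorithm: it assumes that the reduced matrix has entries $\bar{H}_{st} = \mathbf{g}_s^\top \mathbf{H}_{st} \mathbf{g}_t$, where $\mathbf{g}_s$ is the gradient with respect to $\mathbf{T}_s$ and $\mathbf{H}_{st}$ the corresponding Hessian block, and that the per-tensor step sizes $\boldsymbol{\eta}$ solve $\bar{\mathbf{H}} \boldsymbol{\eta} = \bar{\mathbf{g}}$ with $\bar{g}_s = \|\mathbf{g}_s\|^2$. All function names are hypothetical. The explicit Hessian `A` appears only because the toy loss is quadratic; according to the abstract, the method itself relies on automatic differentiation and never forms the full Hessian.

```python
def solve(M, v):
    """Solve M x = v by Gaussian elimination with partial pivoting."""
    n = len(v)
    aug = [list(M[i]) + [v[i]] for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            f = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= f * aug[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        s = sum(aug[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (aug[i][n] - s) / aug[i][i]
    return x


def loss(A, b, theta):
    """Toy loss L(theta) = 0.5 * theta^T A theta - b^T theta."""
    n = len(theta)
    quad = sum(theta[i] * A[i][j] * theta[j] for i in range(n) for j in range(n))
    return 0.5 * quad - sum(b[i] * theta[i] for i in range(n))


def gradient(A, b, theta):
    """Gradient of the toy loss: A theta - b (its Hessian is A)."""
    n = len(theta)
    return [sum(A[i][j] * theta[j] for j in range(n)) - b[i] for i in range(n)]


def partition_step(A, b, theta, partition):
    """One update with one step size per block of the partition.

    `partition` is a list of index lists, one per tensor T_s.
    Assumed form: H_bar[s][t] = g_s^T H_st g_t and g_bar[s] = ||g_s||^2.
    """
    g = gradient(A, b, theta)
    S = len(partition)
    H_bar = [[sum(g[i] * A[i][j] * g[j]
                  for i in partition[s] for j in partition[t])
              for t in range(S)] for s in range(S)]
    g_bar = [sum(g[i] ** 2 for i in partition[s]) for s in range(S)]
    eta = solve(H_bar, g_bar)  # S x S system instead of a full Hessian solve
    new_theta = list(theta)
    for s, block in enumerate(partition):
        for i in block:
            new_theta[i] -= eta[s] * g[i]
    return new_theta, eta


# Small symmetric positive-definite example with three scalar parameters.
A = [[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]
b = [1.0, 2.0, 3.0]
theta0 = [0.0, 0.0, 0.0]

theta_coarse, eta_coarse = partition_step(A, b, theta0, [[0, 1, 2]])   # S = 1
theta_mid, eta_mid = partition_step(A, b, theta0, [[0], [1, 2]])       # S = 2
theta_fine, eta_fine = partition_step(A, b, theta0, [[0], [1], [2]])   # S = 3
```

Under these assumptions, the single-block partition yields the steepest-descent step with exact line search on the quadratic, while the one-parameter-per-block partition lands on the minimizer of the quadratic in one step, as Newton's method would (provided no gradient component is zero, so that the reduced matrix stays invertible). The intermediate partition sits between the two: it solves only a 2 x 2 system but still accounts for the coupling between its blocks.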
