Convergence rates for shallow neural networks learned by gradient descent (Q6137712)

From MaRDI portal
scientific article; zbMATH DE number 7788892

    Statements

    Convergence rates for shallow neural networks learned by gradient descent (English)
    16 January 2024
    The objective of this study is to investigate the following question, with a focus on the analysis of the \(L_2\) error: is it possible to obtain convergence rate results for one-hidden-layer neural network estimators learned by gradient descent in a nonparametric regression setting? To address this question:

    \begin{itemize}
    \item First, the authors derive the rate of convergence of a simple one-hidden-layer neural network regression estimate whose weights are learned by gradient descent, under certain conditions involving the Fourier transform (see Sec. 1.5). The results show that choosing the logistic squasher as activation function, initializing the weights of the network randomly from suitable uniform distributions, and performing roughly \(n^{1.75}\) gradient descent steps with step size roughly \(1/n^{1.25}\) (up to logarithmic factors) on a suitably regularized empirical \(L_2\) risk leads to a truncated estimate that achieves a rate of convergence of \(1/\sqrt{n}\). This is consistent with the results of \textit{A. R. Barron} [Mach. Learn. 14, No. 1, 115--133 (1994; Zbl 0818.68127)] and extends them to neural network predictions learned by gradient descent. The proof further reveals that this result is mainly due to the careful choice of the initial inner weights and the adaptation of the outer weights of the network by gradient descent. In addition, a minimax lower bound on the rate of convergence is established, showing that for large dimension \(d\) the achieved rate is close to the minimax optimal rate, whose exponent approaches \(-1/2\).
    \item They then use this theoretical insight to modify the estimator. Since the optimization of the inner weights by gradient descent is not needed for the result, it should suffice to minimize the empirical risk over the outer weights of the network only. With the inner weights held fixed, this reduces to a linear least squares problem whose optimal weights can be computed directly by solving a system of linear equations. The authors therefore introduce a linear least squares estimator with randomly chosen inner weights and show that it achieves the same rate of convergence as the neural network estimate learned by gradient descent. An important advantage of this estimator is that it can be computed much faster in practice (a schematic code sketch of both estimators, under illustrative assumptions, follows this review).
    \item Finally, they compare their theoretically motivated estimator with standard shallow neural networks learned by gradient descent on simulated data. In many cases their estimator clearly outperforms the standard approach.
    \end{itemize}

    In conclusion, the authors show that the success of their results stems mainly from the proper choice of the initial inner weights and from tuning the outer weights by gradient descent, which suggests that linear least squares methods can be used to determine the outer weights. They support this modification with theoretical results and compare the resulting least squares neural network estimator with standard neural network estimators on simulated data, where the theoretically motivated estimator often performs better.
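    The following Python sketch illustrates the two estimators discussed above: gradient descent on a regularized empirical \(L_2\) risk over the outer weights only (with the inner weights kept at their random initialization) and the corresponding linear least squares estimator with the same random inner weights. It is a minimal illustration, not the authors' implementation: the uniform range of the inner weights, the regularization parameter lam, the truncation level and the simulated data are illustrative assumptions, while the step count of roughly \(n^{1.75}\) and the step size of roughly \(n^{-1.25}\) follow the orders stated above, up to constants and logarithmic factors.

    \begin{verbatim}
import numpy as np

def logistic(x):
    """Logistic squasher activation."""
    return 1.0 / (1.0 + np.exp(-x))

def random_inner_weights(d, K, weight_range=10.0, seed=None):
    """Draw inner weights (a_k, b_k) from uniform distributions; they stay fixed."""
    rng = np.random.default_rng(seed)
    A = rng.uniform(-weight_range, weight_range, size=(K, d))
    b = rng.uniform(-weight_range, weight_range, size=K)
    return A, b

def fit_outer_weights_gd(Phi, y, lam=1e-3, n_steps=None, step_size=None):
    """Gradient descent on the ridge-regularized empirical L2 risk
       R(c) = (1/n) * ||Phi c - y||^2 + lam * ||c||^2
    over the outer weights c only."""
    n, K = Phi.shape
    # orders of magnitude as in the review: about n^{1.75} steps of size about n^{-1.25}
    n_steps = int(n ** 1.75) if n_steps is None else n_steps
    step_size = n ** (-1.25) if step_size is None else step_size
    c = np.zeros(K)
    for _ in range(n_steps):
        grad = (2.0 / n) * Phi.T @ (Phi @ c - y) + 2.0 * lam * c
        c -= step_size * grad
    return c

def fit_outer_weights_lsq(Phi, y, lam=1e-3):
    """Closed-form ridge-regularized least squares solve for the outer weights."""
    K = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + lam * np.eye(K), Phi.T @ y)

# toy usage on simulated regression data
rng = np.random.default_rng(0)
n, d, K = 200, 2, 25
X = rng.uniform(-1.0, 1.0, size=(n, d))
y = np.sin(np.pi * X[:, 0]) + 0.1 * rng.standard_normal(n)
A, b = random_inner_weights(d, K, seed=1)
Phi = logistic(X @ A.T + b)            # hidden-layer outputs, shape (n, K)
trunc = 5.0                            # truncation level for the prediction
for name, c in [("gradient descent", fit_outer_weights_gd(Phi, y)),
                ("least squares", fit_outer_weights_lsq(Phi, y))]:
    y_hat = np.clip(Phi @ c, -trunc, trunc)
    print(name, "empirical L2 error:", np.mean((y - y_hat) ** 2))
    \end{verbatim}

    Since the problem in the outer weights is convex, both routines target the same regularized least squares solution; the closed-form solve simply reaches it in one step, which reflects the computational advantage of the modified estimator emphasized above.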
    deep learning
    gradient descent
    neural networks
    rate of convergence