Paper Reading: Method of Successive Approximations

Original Paper: [1803.01299] An Optimal Control Approach to Deep Learning and Applications to Discrete-Weight Neural Networks


The Optimal Control Viewpoint

Neural Networks as a Dynamical System: Let $T \in \mathbb{Z}_+$ denote the number of layers and $\{x_{s,0} \in \mathbb{R}^{d_0}: s = 1, 2, \cdots, S\}$ represent the $S$ inputs (e.g. images or time series). Consider the dynamical system

$$ x_{s,t+1}=f_t(x_{s,t},\theta_t),\quad t=0,1,\ldots,T-1, $$

where $f_{t}:\mathbb{R}^{d_{t}}\times\Theta_{t}\to\mathbb{R}^{d_{t+1}}$ is the transformation applied to the state at layer $t$ and $\Theta_t$ is its set of trainable parameters.
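As a concrete illustration (my own toy choice, not something specified in the paper), take $f_t(x, \theta_t) = \tanh(W_t x + b_t)$ with $\theta_t = (W_t, b_t)$; the forward dynamics are then just a loop over layers:

```python
import numpy as np

def f_t(x, theta):
    """One step of the dynamics x_{t+1} = f_t(x_t, theta_t).
    Illustrative choice: theta = (W, b) and f_t(x, theta) = tanh(W x + b)."""
    W, b = theta
    return np.tanh(W @ x + b)

def forward(x0, thetas):
    """Propagate a single sample through all T layers; returns the trajectory [x_0, ..., x_T]."""
    xs = [x0]
    for theta in thetas:
        xs.append(f_t(xs[-1], theta))
    return xs

# toy setup: T = 3 layers, all state dimensions d_t = 4
rng = np.random.default_rng(0)
thetas = [(rng.standard_normal((4, 4)), rng.standard_normal(4)) for _ in range(3)]
traj = forward(rng.standard_normal(4), thetas)   # [x_0, x_1, x_2, x_3]
```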

Objective of Training: The goal of training is to adjust the weights $\boldsymbol{\theta} := \{\theta_t: t = 0, 1, \cdots, T-1\}$ so as to minimize a loss function measuring the discrepancy between the final outputs $x_{s,T}$ and the true targets $y_s$ associated with the inputs $x_{s,0}$.

Statement of the Problem: For each sample, define a terminal loss $\Phi_s: \mathbb{R}^{d_T} \to \mathbb{R}$ (which encodes the target $y_s$), so that the average loss is

$$ \frac{1}{S}\sum_s \Phi_s(x_{s,T}) $$

We also include regularization (running cost) terms $L_t: \mathbb{R}^{d_t} \times \Theta_t \to \mathbb{R}$, so the training problem becomes

$$ \min_{\boldsymbol{\theta}\in\boldsymbol{\Theta}}J(\boldsymbol{\theta}):=\frac{1}{S}\sum_{s=1}^{S}\Phi_{s}(x_{s,T})+\frac{1}{S}\sum_{s=1}^{S}\sum_{t=0}^{T-1}L_{t}(x_{s,t},\theta_{t}), $$

$$ x_{s,t+1}=f_t(x_{s,t},\theta_t), \quad t=0,\cdots,T-1, \quad s\in \{1,2,\cdots,S\} $$

where $\boldsymbol{\Theta} := \Theta_0 \times \cdots \times \Theta_{T-1}$.
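Continuing the toy setting above and assuming (again, purely for illustration) a quadratic terminal loss $\Phi_s(x) = \|x - y_s\|^2$ and an $x$-independent running cost $L_t(x, \theta_t) = \lambda\|\theta_t\|^2$, the objective $J$ can be evaluated as follows:

```python
def J(thetas, inputs, targets, lam=1e-3):
    """Objective J(theta): average terminal loss plus running (regularization) cost.
    Illustrative choices: Phi_s(x) = ||x - y_s||^2, L_t(x, theta_t) = lam * ||theta_t||^2."""
    S = len(inputs)
    terminal = sum(np.sum((forward(x0, thetas)[-1] - y) ** 2)
                   for x0, y in zip(inputs, targets)) / S
    # L_t has no x- or s-dependence here, so (1/S) * sum_s sum_t L_t reduces to sum_t L_t
    running = sum(lam * (np.sum(W ** 2) + np.sum(b ** 2)) for (W, b) in thetas)
    return terminal + running
```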


Pontryagin’s Maximum Principle

Hamiltonian Function: Let $\boldsymbol{\theta}^\ast = \{\theta_0^\ast, \cdots, \theta_{T-1}^\ast\} \in \boldsymbol{\Theta}$ be a solution of the problem above. For each $t$, define the Hamiltonian function $H_{t}:\mathbb{R}^{d_{t}}\times\mathbb{R}^{d_{t+1}}\times\Theta_{t}\to\mathbb{R}$ by

$$ H_t(x,p,\theta):=p\cdot f_t(x,\theta)-\frac{1}{S}L_t(x,\theta), $$

where $p \in \mathbb{R}^{d_{t+1}}$ is the co-state vector.
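Under the same illustrative choices as above, the Hamiltonian is a one-liner (note the $-\frac{1}{S}$ factor in front of the running cost):

```python
def hamiltonian(x, p, theta, S, lam=1e-3):
    """H_t(x, p, theta) = p . f_t(x, theta) - (1/S) * L_t(x, theta),
    with the illustrative f_t and L_t from the sketches above."""
    W, b = theta
    return p @ f_t(x, theta) - lam * (np.sum(W ** 2) + np.sum(b ** 2)) / S
```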

Discrete PMP (Informal Statement): Assume $f_t$ and $\Phi_s$, $s = 1,2,\cdots, S$, are sufficiently smooth in $x$, and that for each $t$ and $x \in \mathbb{R}^{d_t}$ the sets $\{f_t(x,\theta): \theta \in \Theta_t\}$ and $\{L_t(x,\theta): \theta \in \Theta_t\}$ are convex. Then there exist co-state processes $\boldsymbol{p}_{s}^{*}:=\{p_{s,t}^{*}:t=0,\ldots,T\}$ such that

$$ x_{s,t+1}^* = \nabla_p H_t(x_{s,t}^*, p_{s,t+1}^*, \theta_t^*), \quad x_{s,0}^* = x_{s,0} $$

$$ p_{s,t}^* = \nabla_x H_t(x_{s,t}^*, p_{s,t+1}^*, \theta_t^*), \quad p_{s,T}^* = -\frac{1}{S} \nabla \Phi_s(x_{s,T}^*) $$

$$ \sum_{s=1}^S H_t(x_{s,t}^*, p_{s,t+1}^*, \theta_t^*) \geq \sum_{s=1}^S H_t(x_{s,t}^*, p_{s,t+1}^*, \theta), \quad \forall \theta \in \Theta_t $$

for $t = 0, 1, \cdots, T-1$ and $s = 1,2,\cdots, S$.
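The co-state equation is essentially a backward pass: $\nabla_x H_t(x, p, \theta) = (\partial_x f_t)^\top p - \frac{1}{S}\nabla_x L_t$, so (up to the $-\frac{1}{S}$ factor) $p_{s,t}^*$ carries the gradient of the per-sample cost with respect to the intermediate state $x_{s,t}$, i.e. the quantity backpropagation computes. A sketch for the toy model, where $L_t$ has no $x$-dependence:

```python
def backward(xs, thetas, y, S):
    """Co-state recursion: p_T = -(1/S) * grad Phi_s(x_T), p_t = grad_x H_t(x_t, p_{t+1}, theta_t).
    Illustrative Phi_s(x) = ||x - y||^2; L_t has no x-dependence, so it drops out here."""
    T = len(thetas)
    ps = [None] * (T + 1)
    ps[T] = -(2.0 / S) * (xs[T] - y)
    for t in reversed(range(T)):
        W, _ = thetas[t]
        a = xs[t + 1]                                  # = tanh(W x_t + b)
        # grad_x (p . tanh(W x + b)) = W^T (p * (1 - tanh^2))
        ps[t] = W.T @ (ps[t + 1] * (1.0 - a ** 2))
    return ps
```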


The Method of Successive Approximations (MSA)

Statement of the MSA Algorithm: Starting from an initial guess $\boldsymbol{\theta}^{0}=\{\theta_{t}^{0}\in\Theta_{t}:t=0,\ldots,T-1\}$, each iteration consists of three steps:

  • State Equation: $x_{s,t}$ denotes the state of the $s$-th sample at the $t$-th layer, $f_t$ the transformation at the $t$-th layer, and $\theta_t$ the control at the $t$-th layer:

$$ x_{s,t+1}^{\boldsymbol{\theta}^0}=f_t(x_{s,t}^{\boldsymbol{\theta}^0},\theta_t^0),\quad x_{s,0}^{\boldsymbol{\theta}^0}=x_{s,0}, $$

  • Co-State Equation: $p_{s,t}$ denotes the co-state of the $s$-th sample at the $t$-th layer, $\Phi_s$ the loss of the $s$-th sample, and $H_t$ the Hamiltonian function:

$$ p_{s,t}^{\boldsymbol{\theta}^0}=\nabla_xH_t(x_{s,t}^{\boldsymbol{\theta}^0},p_{s,t+1}^{\boldsymbol{\theta}^0},\theta_t^0),\quad p_{s,T}^{\boldsymbol{\theta}^0}=-\frac{1}{S}\nabla\Phi_s(x_{s,T}^{\boldsymbol{\theta}^0}), $$

  • Maximization of the Hamiltonian (a code sketch of this step follows the list):

$$ \theta_t^1=\arg\max_{\theta\in\Theta_t}\sum_{s=1}^SH_t(x_{s,t}^{\boldsymbol{\theta}^0},p_{s,t+1}^{\boldsymbol{\theta}^0},\theta), \quad t=0,\ldots,T-1. $$
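The maximization step is what distinguishes MSA from plain gradient descent on $J$. In general it has no closed form; one simple stand-in (entirely my own choice, not the paper's prescription) is a few steps of gradient ascent on the summed Hamiltonian in the toy setting:

```python
def maximize_hamiltonian(states, costates, theta, t, S, lam=1e-3, steps=50, lr=0.1):
    """Approximate argmax_theta sum_s H_t(x_{s,t}, p_{s,t+1}, theta) by gradient ascent.
    states[s] / costates[s] are the state / co-state trajectories of sample s
    computed under the current parameters."""
    W, b = theta[0].copy(), theta[1].copy()
    for _ in range(steps):
        gW, gb = -2.0 * lam * W, -2.0 * lam * b        # gradient of -(1/S) * sum_s L_t
        for s in range(S):
            x, p = states[s][t], costates[s][t + 1]
            a = np.tanh(W @ x + b)
            d = p * (1.0 - a ** 2)                     # gradient of p . tanh(z) w.r.t. z
            gW += np.outer(d, x)
            gb += d
        W, b = W + lr * gW, b + lr * gb                # ascent step on the Hamiltonian
    return (W, b)
```

For the discrete-weight networks in the paper's title, $\Theta_t$ is a finite set, so this inner maximization would instead be solved over that set; the gradient-ascent loop above is only a placeholder for the continuous toy model.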

MSA Algorithm:

  • Initialize: $\boldsymbol{\theta}^{0}=\{\theta_{t}^{0}\in\Theta_{t}:t=0,\ldots,T-1\}$;

  • For $k = 0$ to $K$ do

    • $x_{s,t+1}^{\boldsymbol{\theta}^k}=f_t(x_{s,t}^{\boldsymbol{\theta}^k},\theta_t^k)$ and $x_{s,0}^{\boldsymbol{\theta}^k}=x_{s,0}$ for all $s$ and $t$;
    • $p_{s,t}^{\boldsymbol{\theta}^k}=\nabla_xH_t(x_{s,t}^{\boldsymbol{\theta}^k},p_{s,t+1}^{\boldsymbol{\theta}^k},\theta_t^k)$, $p_{s,T}^{\boldsymbol{\theta}^k}=-\frac{1}{S}\nabla\Phi_s(x_{s,T}^{\boldsymbol{\theta}^k})$ for all $s$ and $t$;
    • $\theta_t^{k+1}=\arg\max_{\theta\in\Theta_t}\sum_{s=1}^SH_t(x_{s,t}^{\boldsymbol{\theta}^k},p_{s,t+1}^{\boldsymbol{\theta}^k},\theta)$ for $t = 0, \cdots, T-1$;
  • End for
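Putting the three steps together with the helpers sketched above gives a complete toy MSA loop; note that every $\theta_t^{k+1}$ is computed from the trajectories generated under $\boldsymbol{\theta}^k$:

```python
def msa(inputs, targets, thetas, K=10, lam=1e-3):
    """Method of Successive Approximations on the toy model:
    forward pass -> co-state (backward) pass -> layer-wise Hamiltonian maximization."""
    S, T = len(inputs), len(thetas)
    for _ in range(K):
        states = [forward(x0, thetas) for x0 in inputs]                            # state equation
        costates = [backward(xs, thetas, y, S) for xs, y in zip(states, targets)]  # co-state equation
        thetas = [maximize_hamiltonian(states, costates, thetas[t], t, S, lam)     # Hamiltonian step
                  for t in range(T)]
    return thetas
```

For example, with `inputs` and `targets` being lists of length-4 arrays matching the toy dimensions, `thetas = msa(inputs, targets, thetas, K=20)` runs twenty MSA iterations, and `J(thetas, inputs, targets)` can be tracked across iterations to monitor progress.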