
Chapter 6 - Learning Rate and Momentum

Neural Smithing: Supervised Learning in Feedforward Artificial Neural Networks
Russell D. Reed and Robert J. Marks II
Copyright © 1999 Massachusetts Institute of Technology
 

6.2 Momentum

Back-propagation with momentum can be viewed as gradient descent with smoothing. The idea is to stabilize the weight trajectory by making the weight change a combination of the gradient descent term in equation 5.23 plus a fraction of the previous weight change. The weight update rule is [330]

$$\Delta w(t) = -\eta\,\frac{\partial E}{\partial w}(t) + \alpha\,\Delta w(t-1) \qquad (6.5)$$

Figure 6.9: With momentum, the current weight change Δw(t) is a combination of a step down the negative gradient, plus a fraction 0 ≤ α < 1 of the previous weight change. For α ≈ 1, opposing weight change components (horizontal) approximately cancel while complementary components (vertical) sum, producing a smoothing effect on the weight trajectory.

This is sometimes called the generalized delta rule. The weight change Δw(t) is a combination of a step down the negative gradient, plus a fraction 0 ≤ α < 1 of the previous weight change (figure 6.9). For α ≈ 1, opposing weight change components approximately cancel while complementary components sum, leading to a smoothing effect on the weight trajectory. When successive gradients point in the same direction, the terms reinforce each other, leading to accelerated learning.
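A minimal sketch (not from the book) of how the update rule (6.5) might look in code; the gradient function and parameter values here are hypothetical:

```python
import numpy as np

def momentum_step(w, prev_dw, grad_E, eta=0.1, alpha=0.9):
    """One weight update with momentum, equation 6.5.

    w       -- current weight vector
    prev_dw -- previous weight change, delta w(t-1)
    grad_E  -- hypothetical function returning dE/dw at w
    eta     -- learning rate
    alpha   -- momentum parameter, 0 <= alpha < 1
    """
    dw = -eta * grad_E(w) + alpha * prev_dw   # step down the gradient plus momentum
    return w + dw, dw
```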

The smoothing effect of momentum can be illustrated by expansion of (6.5)

$$\Delta w(t) = -\eta \sum_{k=0}^{t} \alpha^{k}\,\frac{\partial E}{\partial w}(t-k) \qquad (6.6)$$

That is, with momentum the weight update is an exponential average of all the previous gradient terms rather than just the most recent term. Because α < 1, the contribution from earlier derivative terms decays with each time step and the sum is dominated by the more recent terms. The time-constant of the system is controlled by α. For small α, the coefficients decay quickly as k increases so the system "forgets" earlier terms quickly. For α close to 1, however, the coefficients decay very slowly and the system has a long memory; the system will be stable but slow to react to changes in the error term.
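A quick numerical check (a sketch, with a made-up gradient sequence and arbitrary parameter values) that the recursive update (6.5) and the unrolled exponential average (6.6) agree:

```python
import numpy as np

rng = np.random.default_rng(0)
eta, alpha, T = 0.1, 0.8, 20
grads = rng.normal(size=T)               # made-up gradient sequence dE/dw(t)

# Recursive form, equation 6.5 (starting from delta w(-1) = 0)
dw, recursive = 0.0, []
for g in grads:
    dw = -eta * g + alpha * dw
    recursive.append(dw)

# Unrolled form, equation 6.6: exponentially weighted sum of past gradients
unrolled = [-eta * sum(alpha**k * grads[t - k] for k in range(t + 1))
            for t in range(T)]

print(np.allclose(recursive, unrolled))  # True
```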

The learning accelerating effect of momentum can be illustrated by considering the case where the derivative is constant, ∂E/∂w(t) = J. This is a reasonable approximation when η is very small so w does not change much with each step. It is also reasonable on flat areas of E(w) where the gradient is small. Then

$$\Delta w(t) = -\eta J \sum_{k=0}^{t} \alpha^{k} \;\to\; \frac{-\eta J}{1-\alpha} \quad \text{as } t \to \infty \qquad (6.7)$$

where the identity $\sum_{k=0}^{\infty} \alpha^{k} = 1/(1-\alpha)$ (for |α| < 1) is used in the last step. Without momentum, Δw(t) would be -ηJ. With momentum, however, Δw(t) = -ηJ/(1-α). Momentum thus has the effect of amplifying the learning rate from η to the effective value η' = η/(1-α).
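A small numerical illustration of this amplification (parameter values chosen arbitrarily): iterating (6.5) with a constant gradient J drives the weight change toward -ηJ/(1-α).

```python
eta, alpha, J = 0.1, 0.9, 2.0        # arbitrary illustrative values

dw = 0.0
for _ in range(200):                 # constant gradient, equation 6.5 iterated
    dw = -eta * J + alpha * dw

print(dw, -eta * J / (1 - alpha))    # both approach -2.0: eta amplified to eta/(1-alpha)
```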

As α → 1, the effective learning rate can become very large, but the time constant also becomes large. Weight changes are affected by error information from many past cycles, which may make it difficult for the system to respond quickly to new conditions in the E(w) surface. The weight trajectory may coast over a minimum but be unable to stop because of the continuing effects of earlier weight changes.

6.2.1 Effects of Momentum

As noted earlier, batch-mode back-propagation with a small learning rate is an approximation of gradient descent. Two problems with gradient descent are (1) when the learning rate is small, progress may be very slow (figure 6.10a) and (2) when the learning rate is too large and the error surface contains "ravines," the weight vector may oscillate wildly from one side of the valley to the other while creeping slowly along the length of the valley to the minimum, an effect sometimes called cross-stitching (figure 6.10b). Upon reaching the neighborhood of a minimum, it may overshoot many times before settling down.

Figure 6.10: Gradient descent trajectories: (a) with a small step size (0.01), pure gradient descent follows a smooth trajectory, but progress may be very slow; (b) cross-stitching: when the step size is too big (0.1) and the error surface has valleys, the trajectory may oscillate wildly from one side to the other while creeping slowly along the length of the valley to the minimum. Upon reaching the neighborhood of a minimum, it may overshoot many times before settling down.

Briefly, momentum has the following effects:

  • It smooths weight changes by filtering out high frequency variations. When the learning rate is too high, momentum tends to suppress cross-stitching because consecutive opposing weight changes tend to cancel. The side to side oscillations across the valley damp out leaving only the components along the axis of the valley, which add up.

  • When successive weight changes are all in the same direction, momentum tends to amplify the effective learning rate to η'=η/(1-α), leading to faster convergence.

  • Momentum may sometimes help the system escape small local minima by giving the state vector enough inertia to coast over small bumps in the error surface.

Cross-stitching is a problem for gradient descent when the learning rate is too large and the error surface has steep-sided ravines with a shallow slope along the axis. (This can be stated more technically in terms of the eigenvalues of the Hessian matrix; see section A.2.) Without momentum, the network has only the gradient information to guide its path. Because the gradient on one side of a steep valley points almost directly across the valley and has only a very small component along it, the weight vector tends to jump back and forth across the valley and progress along the axis of the valley is slow relative to the size of the weight changes. Also, upon reaching the neighborhood of a minimum, the weight vector may overshoot many times before settling down. With the momentum term, side to side oscillations tend to cancel but steps along the axis sum, so progress along the valley is faster. The network can follow the path of the ravine better, so the learning rate can often be increased, leading to faster training times. The learning acceleration and oscillation damping effects of momentum can be seen by comparing figures 6.10 and 6.11.
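The following sketch (not from the book; the quadratic "ravine" and all parameter values are invented for illustration) shows how a small momentum term damps the cross-valley oscillation while maintaining progress along the valley:

```python
import numpy as np

# A made-up quadratic "ravine": steep across w1, shallow along w2
a, b = 20.0, 1.0
grad = lambda w: np.array([a * w[0], b * w[1]])

def run(eta, alpha, steps=10):
    w, dw = np.array([1.0, 1.0]), np.zeros(2)
    for _ in range(steps):
        dw = -eta * grad(w) + alpha * dw          # equation 6.5
        w = w + dw
    return w

for alpha in (0.0, 0.2):
    w = run(eta=0.09, alpha=alpha)
    print(f"alpha={alpha}: cross-valley |w1| = {abs(w[0]):.4f}, along-valley w2 = {w[1]:.4f}")
```

With no momentum the weight bounces from side to side and the cross-valley component decays slowly; with a small momentum term it dies out much faster, while progress along the valley is comparable.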

Figure 6.11: Gradient descent trajectories with momentum. With momentum, opposite (side to side) changes tend to cancel while complementary changes (along the length of the valley) tend to sum. The overall effect is to stabilize the oscillations and accelerate convergence. (a) When the step size is small, momentum acts to accelerate convergence (step size 0.01 and momentum 0.99, cf. figure 6.10a). (b) Small amounts of momentum also help to damp oscillations when the step size is too large (step size 0.1 and momentum 0.2, cf. figure 6.10b).

Too Much Momentum With momentum, the state vector has a tendency to keep moving in the same direction. Weight changes are affected by error information from many past cycles--the larger the momentum, the stronger the lingering influence of previous changes. In effect, momentum gives the weight state inertia, which "keeps the marble rolling," allowing it to coast over flat spots and perhaps out of small local minima.

A little inertia is useful for stabilization but too much may make the system sluggish; it may overshoot good minima or be unable to follow a curved valley in the error surface. The system may coast past minima and out onto high plateaus where it becomes stuck (section 6.2.2).

Interaction with Learning Rate Many studies have claimed that momentum tends to make the choice of learning rate η less critical. When η is too small, successive weight updates tend to be in the same direction and momentum effectively amplifies η to η/(1-α). When η is too large, successive updates tend to be in nearly opposite directions and momentum causes them to cancel, effectively reducing the learning rate. Some support is seen in figure 6.1, where the density function for the probability of convergence is wider (on a logarithmic scale) for α = 0.9 than for α = 0. On a linear scale, however, the width of the density function decreases, in agreement with results showing that momentum reduces the stable range of learning rates for the LMS algorithm [343], [326].

Figure 6.12: At the low end of the momentum range, increases in α generally lead to faster convergence. Here the E(t) curves are smooth. Occasional error spikes may occur but the system recovers quickly. All trajectories start from the same random weights.

6.2.2 Typical E(t) Curves with Momentum

Figures 6.12 and 6.13 show E(t) curves for various momentum values and fixed learning rate. All curves were generated from the same initial weight vector as in figures 6.4 and 6.5. Simulation details are described in section 6.1.1.

At low momentum values, the E(t) curves are smooth and larger values of α lead to faster convergence (assuming reasonable learning rates). Occasional spikes may occur but the system recovers quickly. The curves in figure 6.12 are all qualitatively similar so the increased convergence speed may be due mostly to amplification of the effective learning rate. That is, the system appears to be following the same basic trajectory at varying speeds.

As α increases past a certain point, however, convergence becomes unreliable (figure 6.13). At α = 0.6 the system converges quickly, but at larger values it becomes stuck in a poor minimum. For α = 0.99, E(t) oscillates strongly and the system jumps from a poor minimum to a worse one at about t = 100.

6.2.3 Small-Signal Analysis, Momentum Only

The acceleration and smoothing effects of momentum can be explained in terms of small-signal analyses. The weight update equation with momentum is

$$\Delta w(t) = -\eta\,\frac{\partial E}{\partial w}(t) + \alpha\,\Delta w(t-1) \qquad (6.8)$$

Figure 6.13: E(t) curves for large momentum values. Convergence becomes unreliable when the momentum is too large for the learning rate η. This system converges quickly with α = 0.6 but not with higher values. With α = 0.99, E(t) oscillates strongly and the system jumps from a poor minimum to a worse one at about t = 100. All trajectories start from the same random weights.

Convert this discrete-time iteration to a continuous-time system by the approximation

$$\Delta w(t) \approx \frac{dw}{dt}\,\Delta t$$

and assume Δt = 1. Then

$$\frac{dw}{dt}(t) = -\eta J + \alpha\,\frac{dw}{dt}(t-1) \qquad (6.9)$$

where J = ∂E/∂w. Another discrete-time to continuous-time approximation for the second derivative gives

$$\frac{dw}{dt}(t-1) \approx \frac{dw}{dt}(t) - \frac{d^{2}w}{dt^{2}}(t) \qquad (6.10)$$

and

$$\alpha\,\frac{d^{2}w}{dt^{2}} + (1-\alpha)\,\frac{dw}{dt} = -\eta J \qquad (6.11)$$

This second-order differential equation is easily solved for certain special cases. Laplace transforms (e.g., [292]) are used in the following discussions.

Impulse Response Assume J is an impulse at 0, that is, J(t) = J₀δ(t). This approximates the case of encountering a "cliff" in the E(w) surface, where J is large, and then coasting on a flat plateau where J ≈ 0. Taking Laplace transforms gives,

$$W(s) = \frac{w(0)}{s} - \frac{\eta J_0}{1-\alpha}\cdot\frac{1}{s(\tau s + 1)} \qquad (6.12)$$

where τ = α/(1-α). For 0 < α < 1, this has the solution

$$w(t) = w(0) - \frac{\eta J_0}{1-\alpha}\left(1 - e^{-t/\tau}\right) \qquad (6.13)$$

For α ≥ 1, the solution is unstable. Otherwise w(t) asymptotically approaches a final value w(∞) = w(0) - ηJ₀/(1-α). Instead of taking a single step ηJ₀ at t = 0, it takes many steps that asymptotically add up to a value 1/(1-α) times as large.

Momentum thus has the effect of amplifying the learning rate from η to the effective value η' = η/(1-α). For α ≈ 1, the effective learning rate can become very large, but the time constant τ also becomes large, which may make it difficult for the system to respond quickly to new conditions in the E(w) surface.
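As a rough numerical check (parameter values are arbitrary), the discrete iteration with a single gradient "impulse" at t = 0 can be compared with the continuous-time solution (6.13):

```python
import numpy as np

eta, alpha, J0, w0 = 0.5, 0.9, 1.0, 0.0   # arbitrary illustrative values
tau = alpha / (1 - alpha)

# Discrete iteration: one gradient "impulse" J0 at t = 0, zero gradient afterwards
w, dw, discrete = w0, 0.0, [w0]
for t in range(60):
    g = J0 if t == 0 else 0.0
    dw = -eta * g + alpha * dw
    w += dw
    discrete.append(w)

# Continuous-time impulse response, equation 6.13
t = np.arange(61)
analytic = w0 - eta * J0 / (1 - alpha) * (1 - np.exp(-t / tau))

print(discrete[-1], analytic[-1], w0 - eta * J0 / (1 - alpha))   # all close to -5.0
```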

Step Response Assume J is a step function at t = 0, that is, J(t) = J₀u(t), where u(t) is the unit step function

$$u(t) = \begin{cases} 1, & t \ge 0 \\ 0, & t < 0 \end{cases} \qquad (6.14)$$

This is reasonable when η is small and the local error surface is nearly flat; the gradient changes very little in one iteration and can be approximated by a constant. Laplace transforms give

$$W(s) = \frac{w(0)}{s} - \frac{\eta J_0}{1-\alpha}\cdot\frac{1}{s^{2}(\tau s + 1)} \qquad (6.15)$$

where τ = α/(1-α) again. The solution is unstable for α ≥ 1. For 0 < α < 1, this has the solution

$$w(t) = w(0) - \frac{\eta J_0}{1-\alpha}\left[t - \tau\left(1 - e^{-t/\tau}\right)\right] \qquad (6.16)$$

This approximates a ramp function, a linear rise with t, plus a transient term similar to (6.13). For t >> τ,

$$w(t) \approx w(0) - \frac{\eta J_0}{1-\alpha}\,(t - \tau) \qquad (6.17)$$
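A similar check (again with arbitrary values) shows the ramp behavior: once the transient dies out, the discrete iteration with a constant gradient moves with slope approximately -ηJ₀/(1-α) per step.

```python
eta, alpha, J0 = 0.1, 0.9, 1.0     # arbitrary illustrative values

# Discrete iteration with a constant gradient J0 switched on at t = 0
w, dw, path = 0.0, 0.0, [0.0]
for _ in range(200):
    dw = -eta * J0 + alpha * dw
    w += dw
    path.append(w)

slope = path[-1] - path[-2]
print(slope, -eta * J0 / (1 - alpha))   # both approximately -1.0
```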
Figure 6.14: Oscillation in E(t) due to momentum. With momentum, the weight vector has inertia, which allows it to coast up hillsides in the E(w) surface. The larger the momentum and the smaller the learning rate, the farther it can rise and the longer it takes to stop. This, in combination with the E(w) surface, can lead to oscillation. Mathematically, smaller η and larger α give the dynamic system a longer time-constant, visible here in a lower oscillation frequency. (These values were chosen to exaggerate the effect; they are not necessarily recommended.)

Frequency Response The common term

$$\frac{1}{\tau s + 1}$$

corresponds to a leaky integrator and has a low-pass frequency response. (A leaky integrator is an exponentially weighted averager.) This is another justification for the statement that momentum helps filter out high frequency oscillations in the weight changes. Recall that τ = α/(1-α) becomes large as α → 1. The time-domain impulse response is

$$h(t) \propto e^{-t/\tau}$$

Convolution of a signal J(t) with h(t) yields an exponentially weighted average

$$(h * J)(t) \propto \int_{0}^{\infty} e^{-u/\tau}\,J(t-u)\,du$$

which weights recent J values more heavily than older values.

Generally, smaller η and larger α values give the dynamic system a longer time-constant, visible in figure 6.14 as a lower frequency of oscillation.
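The low-pass filtering can be illustrated with a simple exponentially weighted average (a sketch; the signal and parameter values are made up):

```python
import numpy as np

alpha = 0.9
rng = np.random.default_rng(1)
t = np.arange(500)
# A made-up "gradient" signal: a slow drift plus high-frequency noise
J = np.sin(2 * np.pi * t / 250) + rng.normal(scale=0.5, size=t.size)

# Leaky integrator: exponentially weighted average of J, as momentum performs on the gradient
smoothed = np.zeros_like(J)
for i in range(1, t.size):
    smoothed[i] = alpha * smoothed[i - 1] + (1 - alpha) * J[i]

drift = np.sin(2 * np.pi * t / 250)
print(np.std(J - drift), np.std(smoothed - drift))   # high-frequency content is reduced
```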

6.2.4 Small-Signal Analysis, Momentum, and Weight Decay

It is relatively easy to extend these linear small-signal analyses to include weight decay terms. The new weight update equation is

$$\Delta w(t) = -\eta\,\frac{\partial E}{\partial w}(t) + \alpha\,\Delta w(t-1) - \rho\,w(t) \qquad (6.18)$$

where 0 < ρ << 1 is the weight decay parameter. The same discrete-time to continuous-time approximations give

$$\alpha\,\frac{d^{2}w}{dt^{2}} + (1-\alpha)\,\frac{dw}{dt} + \rho\,w = -\eta J \qquad (6.19)$$

Impulse Response When J is an impulse at 0, that is, J(t) = J₀δ(t), Laplace transforms give

$$W(s) = \frac{(\alpha s + 1 - \alpha)\,w(0) - \eta J_0}{\alpha s^{2} + (1-\alpha)s + \rho} \qquad (6.20)$$

The denominator has roots

$$s = \frac{-\beta \pm \sqrt{\beta^{2} - 4\rho\alpha}}{2\alpha} \qquad (6.21)$$

where β = 1 - α. For β² - 4ρα > 0, both roots are real. Critical damping occurs at

$$\rho_0 = \frac{(1-\alpha)^{2}}{4\alpha} \qquad (6.22)$$

This is a decreasing function of α. As α → 1, ρ must approach 0 to prevent oscillation. For ρ < ρ₀, both roots are real. For larger values, the roots are complex and the solution is an exponentially decaying sinusoid. Convergence of these types of systems is usually fastest when the system is slightly underdamped.
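A quick numerical check of the damping regimes (parameter values chosen arbitrarily):

```python
import numpy as np

alpha = 0.9
beta = 1 - alpha
rho0 = beta**2 / (4 * alpha)             # critical damping, equation 6.22

for rho in (0.5 * rho0, rho0, 2 * rho0):
    disc = beta**2 - 4 * rho * alpha
    roots = np.roots([alpha, beta, rho]) # roots of alpha*s^2 + (1-alpha)*s + rho
    if np.isclose(disc, 0):
        kind = "repeated root (critically damped)"
    elif disc > 0:
        kind = "real roots (overdamped)"
    else:
        kind = "complex roots (underdamped, decaying oscillation)"
    print(f"rho = {rho:.5f}: {np.round(roots, 4)} -> {kind}")
```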

Step Response When J is a step function at t = 0, that is, J(t) = J₀u(t), Laplace transforms give

$$W(s) = \frac{(\alpha s + 1 - \alpha)\,w(0) - \eta J_0/s}{\alpha s^{2} + (1-\alpha)s + \rho} \qquad (6.23)$$

Similar arguments are made by Bailey [15] using the damped oscillator equation

$$\frac{d^{2}w}{dt^{2}} + b\,\frac{dw}{dt} + k\,w = -\frac{\eta}{\alpha}\,J \qquad (6.24)$$

Here b = (1-α)/α and k = ρ/α. They require b < 1/N in order to average over all N patterns in the training set and choose b = 1/(2N) as a reasonable value. This corresponds to α = 2N/(2N + 1). Critical damping occurs when k = (b/2)², which corresponds to the value in equation 6.22.
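For a concrete illustration (the training-set size N is chosen arbitrarily), these relations can be checked numerically:

```python
N = 50                              # arbitrary training-set size for illustration
b = 1 / (2 * N)                     # Bailey's suggested value, b = 1/(2N)
alpha = 2 * N / (2 * N + 1)         # the corresponding momentum, alpha = 2N/(2N+1)

print(abs((1 - alpha) / alpha - b) < 1e-12)   # b = (1 - alpha)/alpha holds

k = (b / 2) ** 2                    # critical damping, k = (b/2)^2
rho = k * alpha                     # since k = rho/alpha
print(abs(rho - (1 - alpha)**2 / (4 * alpha)) < 1e-12)   # matches equation 6.22
```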