Adam
Adam (adaptive moment estimation) is an optimization algorithm that combines the ideas behind Momentum and RMSprop. It keeps exponentially decaying averages of both the gradients (the first moment, i.e. the mean) and the squared gradients (the second moment, i.e. the uncentered variance), and uses them to compute an adaptive learning rate for each parameter.
Mathematical definition
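Adam maintains two moving averages, $m_t$ (first moment) and $v_t$ (second moment), both initialized to zero ($m_0 = v_0 = 0$). At each step $t$ they are updated with the current gradient (squared element-wise for $v_t$):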
$$ m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot \nabla \mathcal{L}(\theta_t) $$
$$ v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot (\nabla \mathcal{L}(\theta_t))^2 $$
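Because $m_t$ and $v_t$ start at zero, they are biased toward zero during the first steps. Adam corrects this bias by dividing the estimates by $1 - \beta_1^t$ and $1 - \beta_2^t$ respectively, factors that approach 1 as $t$ grows: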
$$ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} $$
$$ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} $$
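To see what the correction does, here is the first step ($t = 1$) worked out by hand, assuming $m_0 = v_0 = 0$, the typical decay rates $\beta_1 = 0.9$ and $\beta_2 = 0.999$, and writing $g_1$ as shorthand for $\nabla \mathcal{L}(\theta_1)$ (a notation introduced only for this example):

$$ m_1 = (1 - \beta_1)\, g_1 = 0.1\, g_1, \qquad v_1 = (1 - \beta_2)\, g_1^2 = 0.001\, g_1^2 $$
$$ \hat{m}_1 = \frac{0.1\, g_1}{1 - 0.9} = g_1, \qquad \hat{v}_1 = \frac{0.001\, g_1^2}{1 - 0.999} = g_1^2 $$
$$ \theta_2 = \theta_1 - \frac{\eta}{|g_1| + \epsilon} \cdot g_1 \approx \theta_1 - \eta \cdot \operatorname{sign}(g_1) $$

So, whatever the gradient's magnitude, the very first update moves each parameter by roughly $\eta$. This is consistent with the first epoch of the Sorix example in this section, where `lr=0.5` takes both $x$ and $y$ from 5.0 to 4.5.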
The parameters are updated using:
$$ \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t $$
where:
- $\beta_1, \beta_2$: Exponential decay rates for the moment estimates (typically 0.9 and 0.999).
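- $t$: Time step (iteration count), starting at 1.
- $\epsilon$: Small constant added to the denominator for numerical stability, preventing division by zero (commonly $10^{-8}$).
- $\eta$: Learning rate (the `lr` argument).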
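The update rules translate almost line for line into NumPy. The sketch below is a standalone reference implementation for illustration, not Sorix's internal code; the function name `adam_update` and its signature are assumptions, and the gradient of the test function is written out analytically instead of being computed by autograd.

```python
import numpy as np


def adam_update(theta, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam step (t is the 1-based step count); return (theta, m, v)."""
    # Exponential moving averages of the gradient and the squared gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    # Bias correction for the zero initialization of m and v
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Element-wise adaptive step
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


# Minimize f(x, y) = x^2 + 10*y^2 from (5, 5) with lr = 0.5
theta = np.array([5.0, 5.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
history, trajectory = [], []
for t in range(1, 11):
    x, y = theta
    history.append(x * x + 10 * y * y)     # loss before the update
    grad = np.array([2 * x, 20 * y])       # analytic gradient of f
    theta, m, v = adam_update(theta, grad, m, v, t, lr=0.5)
    trajectory.append(theta.copy())
    print(f"step {t}: x = {theta[0]:.4f}, y = {theta[1]:.4f}, loss = {history[-1]:.4f}")
```

Its first two steps (x = y = 4.5, then about 4.0021) agree with the Sorix log in the example in this section, which suggests Sorix follows the same standard formulation.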
Implementation details
In Sorix, the Adam optimizer is lightweight: each step costs a few element-wise operations per parameter, and its per-parameter state is just the $m_t$ and $v_t$ buffers, which it stores as lists. It also automatically performs the calculations on the appropriate device (CPU or GPU).
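To make the "state as lists" description concrete, here is a minimal sketch of how such an optimizer can be organized: one $m_t$ and one $v_t$ NumPy buffer per parameter, kept in Python lists aligned with the parameter list. `ListAdam` and `Param` are hypothetical names for illustration only; this is not Sorix's source, and it runs on the CPU only (no GPU dispatch).

```python
import numpy as np


class Param:
    """Hypothetical stand-in for a tensor: a value plus its gradient."""

    def __init__(self, data):
        self.data = np.asarray(data, dtype=float)
        self.grad = None


class ListAdam:
    """Illustrative Adam optimizer that keeps m_t and v_t as Python lists."""

    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.t = 0
        # One first-moment and one second-moment buffer per parameter
        self.m = [np.zeros_like(p.data) for p in self.params]
        self.v = [np.zeros_like(p.data) for p in self.params]

    def zero_grad(self):
        for p in self.params:
            p.grad = np.zeros_like(p.data)

    def step(self):
        self.t += 1
        b1, b2 = self.beta1, self.beta2
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue  # parameter received no gradient this step
            self.m[i] = b1 * self.m[i] + (1 - b1) * p.grad
            self.v[i] = b2 * self.v[i] + (1 - b2) * p.grad ** 2
            m_hat = self.m[i] / (1 - b1 ** self.t)
            v_hat = self.v[i] / (1 - b2 ** self.t)
            p.data = p.data - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


# Usage: two parameters of different shapes, loss = sum(w**2) + sum(b**2)
w = Param([3.0, -2.0])
b = Param([1.5])
opt = ListAdam([w, b], lr=0.1)
for _ in range(200):
    opt.zero_grad()
    w.grad = 2 * w.data  # gradients written by hand (no autograd here)
    b.grad = 2 * b.data
    opt.step()
print(w.data, b.data)
print([s.shape for s in opt.m])
```

Because each buffer is created with `np.zeros_like(p.data)`, the extra memory in this sketch is exactly two arrays per parameter, each with that parameter's shape.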
```python
# Uncomment the next line and run this cell to install sorix
#!pip install 'sorix @ git+https://github.com/Mitchell-Mirano/sorix.git@main'
import numpy as np
from sorix import tensor
from sorix.optim import Adam
import sorix

# Minimizing an anisotropic function with Adam: f(x, y) = x^2 + 10*y^2
# Adam scales each parameter's step by that parameter's own gradient history,
# so the 10x steeper y-direction does not force a smaller learning rate.
x = tensor([5.0], requires_grad=True)
y = tensor([5.0], requires_grad=True)
optimizer = Adam([x, y], lr=0.5)

for epoch in range(10):
    # Loss at the current point; the printed loss is therefore the pre-update value
    loss = x * x + tensor([10.0]) * y * y
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}: x = {x.data[0]:.4f}, y = {y.data[0]:.4f}, loss = {loss.data[0]:.4f}")
```
```
Epoch 1: x = 4.5000, y = 4.5000, loss = 275.0000
Epoch 2: x = 4.0021, y = 4.0021, loss = 222.7500
Epoch 3: x = 3.5079, y = 3.5079, loss = 176.1814
Epoch 4: x = 3.0197, y = 3.0197, loss = 135.3614
Epoch 5: x = 2.5398, y = 2.5398, loss = 100.3042
Epoch 6: x = 2.0712, y = 2.0712, loss = 70.9574
Epoch 7: x = 1.6171, y = 1.6171, loss = 47.1878
Epoch 8: x = 1.1813, y = 1.1813, loss = 28.7653
Epoch 9: x = 0.7679, y = 0.7679, loss = 15.3507
Epoch 10: x = 0.3812, y = 0.3812, loss = 6.4868
```
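The log shows $x$ and $y$ moving in lockstep even though the $y$-gradient is ten times larger. This follows directly from the update rules. Writing $g^{(x)}_t = 2x_t$ and $g^{(y)}_t = 20y_t$ for the two gradients and using superscripts $(x)$, $(y)$ for each parameter's moments (notation introduced only for this derivation), suppose $x_s = y_s$ for all steps $s \le t$ (true at the start, where both equal 5). Then $g^{(y)}_s = 10\, g^{(x)}_s$, and because both moment recurrences are linear and start from zero:

$$ m^{(y)}_t = 10\, m^{(x)}_t, \qquad v^{(y)}_t = 100\, v^{(x)}_t $$
$$ \frac{\hat{m}^{(y)}_t}{\sqrt{\hat{v}^{(y)}_t} + \epsilon} = \frac{10\, \hat{m}^{(x)}_t}{10\sqrt{\hat{v}^{(x)}_t} + \epsilon} \approx \frac{\hat{m}^{(x)}_t}{\sqrt{\hat{v}^{(x)}_t} + \epsilon} $$

Both coordinates therefore take the same step (up to the negligible $\epsilon$ term) and stay equal. For comparison, plain gradient descent with the same learning rate would update $y \leftarrow y - 0.5 \cdot 20y = -9y$, multiplying $|y|$ by 9 at every step and diverging.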