Optimizer¶
Sorix implements a Fused/Vectorized Optimizer Architecture. Unlike traditional frameworks that iterate over parameters with Python loops, Sorix consolidates all model parameters into contiguous memory buffers.
This design reduces CPU overhead and maximizes GPU throughput by replacing many small per-parameter operations with a few large vectorized ones.
1. Anatomy of the Fused Optimizer¶
Every optimizer in Sorix inherits from the Optimizer base class, which automatically handles the heavy lifting of memory management:
- Parameter Groups (`param_groups`): parameters are organized into groups, each with its own hyperparameters (such as learning rate or weight decay).
- Master Buffers: for each group, Sorix consolidates parameters into `group['_param_buffer']` and gradients into `group['_grad_buffer']`.
- Memory Views: individual parameter attributes (`p.data` and `p.grad`) become views of the group's master buffers. Modifying a parameter directly modifies the buffer, and vice versa.
- Zero Grad O(1): clearing gradients is a single operation per group: `group['_grad_buffer'].fill(0)`.
- Vectorized Updates: the update logic runs on the entire group in a single vectorized instruction.
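The buffer-and-views idea can be sketched in plain NumPy (a minimal illustration of the principle, not Sorix's actual implementation; the names `master`, `grads`, and `views` are invented for this example):

```python
import numpy as np

# Two "parameters": a (2, 2) weight and a (2,) bias, fused into one buffer.
shapes = [(2, 2), (2,)]
sizes = [int(np.prod(s)) for s in shapes]

master = np.zeros(sum(sizes))  # contiguous master parameter buffer
grads = np.zeros_like(master)  # contiguous gradient buffer

# Build per-parameter views into the master buffer.
views, offset = [], 0
for shape, size in zip(shapes, sizes):
    views.append(master[offset:offset + size].reshape(shape))
    offset += size

# Writing through a view mutates the master buffer...
views[0][0, 0] = 3.0
assert master[0] == 3.0

# ...clearing gradients is one O(1) call...
grads.fill(0)

# ...and a single vectorized update moves every parameter at once.
grads[:] = 1.0
master -= 0.1 * grads
print(views[0][0, 0])  # 2.9 — the view reflects the fused update
```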
# Uncomment to install sorix
#!pip install 'sorix @ git+https://github.com/Mitchell-Mirano/sorix.git@main'
2. Implementing a Custom Fused Optimizer¶
Let's implement Fused SignSGD. Instead of scaling the update by each gradient's magnitude, SignSGD moves by a fixed step in the direction of the gradient's sign.
We will demonstrate how a model actually learns using this optimizer by solving a simple linear regression problem.
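The update rule itself is just `p ← p − lr · sign(∇p)`. Before wiring it into Sorix, here is a quick framework-independent NumPy illustration on a 1-D quadratic:

```python
import numpy as np

lr = 0.1
x = np.array([3.0])          # start away from the minimum of f(x) = x**2

for _ in range(50):
    grad = 2.0 * x           # gradient of x**2
    x -= lr * np.sign(grad)  # fixed-size step in the sign direction

# Sign steps have constant size, so x cannot settle closer than one
# step to the optimum; it ends within lr of 0.
print("final x:", x[0])
```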
import numpy as np

from sorix.optim import Optimizer
from sorix import tensor
from sorix.nn import Linear, MSELoss


class FusedSignSGD(Optimizer):
    def __init__(self, parameters, lr=0.01):
        super().__init__(parameters, lr)

    def _perform_step(self):
        # Iterate over param_groups for flexibility
        for group in self.param_groups:
            xp = group['_xp']  # array backend for this group
            p_buf = group['_param_buffer']
            g_buf = group['_grad_buffer']

            # NO LOOPS on parameters!
            # Update the entire GROUP at once using its master buffer.
            p_buf -= group['lr'] * xp.sign(g_buf)
# 1. Setup Data: y = 2x1 + 3x2 + 5
X = tensor(np.random.randn(100, 2))
true_weights = np.array([[2.0], [3.0]])
true_bias = 5.0
y = tensor(X.numpy() @ true_weights + true_bias)
# 2. Initialize Model and Optimizer
model = Linear(2, 1)
optimizer = FusedSignSGD(model.parameters(), lr=0.1)
criterion = MSELoss()
print("Training Start...")
print(f"Initial Weights: {model.W.data.ravel()}")
print(f"Initial Bias: {model.b.data.ravel()}")
# 3. Training Loop
for epoch in range(51):
    optimizer.zero_grad()
    y_pred = model(X)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch:2d} | Loss: {loss.item():.6f}")
print("\nTraining End...")
print(f"Final Weights: {model.W.data.ravel()} (Expected near [2, 3])")
print(f"Final Bias: {model.b.data.ravel()} (Expected near 5)")
Training Start...
Initial Weights: [0.9601267 0.6628013]
Initial Bias: [0.]
Epoch  0 | Loss: 28.530500
Epoch 10 | Loss: 16.068933
Epoch 20 | Loss: 8.683997
Epoch 30 | Loss: 3.863858
Epoch 40 | Loss: 0.967390
Epoch 50 | Loss: 0.004616

Training End...
Final Weights: [2.0601268 2.9628007] (Expected near [2, 3])
Final Bias: [5.0999975] (Expected near 5)
3. High-Performance State Management¶
Optimizers like Adam or RMSProp need to track history (moments). In Sorix, we create additional buffers that match the master buffer size to keep everything contiguous in memory.
class FusedMomentum(Optimizer):
    def __init__(self, parameters, lr=0.01, momentum=0.9):
        super().__init__(parameters, lr)
        for group in self.param_groups:
            group.setdefault('momentum', momentum)
            # Create a momentum buffer per group in its persistent state
            group['state']['v_buffer'] = group['_xp'].zeros_like(group['_param_buffer'])

    def _perform_step(self):
        for group in self.param_groups:
            p_buf = group['_param_buffer']
            g_buf = group['_grad_buffer']
            v_buf = group['state']['v_buffer']

            # All momentum calculations happen in bulk for the entire group:
            # v = momentum * v + g
            v_buf[:] = group['momentum'] * v_buf + g_buf
            p_buf -= group['lr'] * v_buf


print("FusedMomentum implemented with vectorized state management!")
FusedMomentum implemented with vectorized state management!
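To sanity-check the fused formulation, a pure-NumPy sketch (independent of Sorix; all names here are invented for the example) can confirm that one bulk update over a fused buffer matches the classic per-parameter momentum loop:

```python
import numpy as np

rng = np.random.default_rng(0)
lr, momentum = 0.1, 0.9

# Two "parameters" with matching gradients.
params = [rng.normal(size=(3, 2)), rng.normal(size=3)]
grads = [rng.normal(size=(3, 2)), rng.normal(size=3)]

# Fused buffers built from the same initial values.
p_buf = np.concatenate([p.ravel() for p in params])
g_buf = np.concatenate([g.ravel() for g in grads])
v_buf = np.zeros_like(p_buf)

# Reference: classic per-parameter momentum loop.
vels = [np.zeros_like(p) for p in params]
for p, g, v in zip(params, grads, vels):
    v[:] = momentum * v + g
    p -= lr * v

# Fused: the whole group updated in two vectorized lines.
v_buf[:] = momentum * v_buf + g_buf
p_buf -= lr * v_buf

# Both paths yield identical parameters.
ref = np.concatenate([p.ravel() for p in params])
print(np.allclose(p_buf, ref))  # True
```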
4. Verification: The "View" Principle¶
A key feature of Sorix is that `p.data` is just a view of the optimizer's master buffer, `group['_param_buffer']`.
Let's prove it: if we modify the master buffer, the model weights change instantly without any assignment.
model = Linear(2, 2)
optimizer = FusedSignSGD(model.parameters(), lr=0.01)
print(f"Original Weight [0,0]: {model.W.data[0,0]:.4f}")
# Manually modify the START of the first group's master buffer
optimizer.param_groups[0]['_param_buffer'][0] = 999.0
print(f"New Weight [0,0]: {model.W.data[0,0]:.4f} (Updated automatically via View!)")
Original Weight [0,0]: 1.0078
New Weight [0,0]: 999.0000 (Updated automatically via View!)
Conclusion¶
Sorix's Fused Optimizer architecture provides high performance with a simple API. By operating on `group['_param_buffer']` and `group['_grad_buffer']`, you can implement state-of-the-art algorithms that run with maximum efficiency on both CPU and GPU.
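As a closing illustration of the pattern for stateful algorithms like Adam, here is a hedged pure-NumPy sketch of one fused Adam step over a single master buffer (buffer names mirror this document's conventions; this is not Sorix's implementation):

```python
import numpy as np

lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8

p_buf = np.array([1.0, -2.0, 3.0])   # fused parameter buffer
g_buf = np.array([0.5, -0.5, 0.25])  # fused gradient buffer

# Contiguous state buffers, one per moment, matching the master buffer.
m_buf = np.zeros_like(p_buf)
v_buf = np.zeros_like(p_buf)

t = 1  # step counter
m_buf[:] = beta1 * m_buf + (1 - beta1) * g_buf
v_buf[:] = beta2 * v_buf + (1 - beta2) * g_buf**2
m_hat = m_buf / (1 - beta1**t)       # bias correction
v_hat = v_buf / (1 - beta2**t)
p_buf -= lr * m_hat / (np.sqrt(v_hat) + eps)

# On the first step the update is ~lr in each gradient's sign direction.
print(p_buf)
```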