BCEWithLogitsLoss¶
Binary Cross Entropy (BCE) measures the dissimilarity between the true binary targets and the predicted probabilities. For numerical stability, Sorix implements BCEWithLogitsLoss, which fuses a Sigmoid activation into the loss function.
The total loss is the average over all $n$ samples in the batch:
$$L = - \frac{1}{n} \sum_{i=1}^{n} [y_i \ln(\sigma(\hat{y}_i)) + (1 - y_i) \ln(1 - \sigma(\hat{y}_i))]$$
Where:
- $\hat{y}_i$ are the log-odds (logits).
- $\sigma(x) = \frac{1}{1 + e^{-x}}$ is the Sigmoid function.
- $y_i$ is the target (0 or 1).
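To make the formula concrete, here is a direct NumPy transcription (a naive sketch for illustration; `sigmoid` and `bce_naive` are names chosen here, not part of Sorix):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_naive(logits, targets):
    # Direct transcription of the formula above; unstable for large |logits|
    p = sigmoid(logits)
    return -np.mean(targets * np.log(p) + (1.0 - targets) * np.log(1.0 - p))

logits = np.array([2.0, -1.0, 0.0])
targets = np.array([1.0, 0.0, 1.0])
print(f"{bce_naive(logits, targets):.4f}")  # 0.3778
```

This works fine for moderate logits; the next section shows where it breaks down.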
Numerical Stability: The Log-Sum-Exp Trick¶
Directly calculating $\ln(\sigma(x))$ or $\ln(1 - \sigma(x))$ can be numerically unstable. If $x$ is a large positive number, $\sigma(x)$ rounds to exactly $1$ in floating point, so $\ln(1 - \sigma(x)) = \ln(0) = -\infty$. Symmetrically, for a large negative $x$, $\sigma(x)$ rounds to $0$ and $\ln(\sigma(x))$ diverges; computing $e^{-x}$ for such $x$ can also overflow.
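A minimal NumPy demonstration of the failure mode (in float64, $e^{-x}$ overflows once $-x$ exceeds roughly $709$, so $x = -1000$ is well past the cliff):

```python
import numpy as np

x = np.float64(-1000.0)               # a large negative logit
with np.errstate(over="ignore", divide="ignore"):
    p = 1.0 / (1.0 + np.exp(-x))      # exp(1000) overflows to inf, so p rounds to 0.0
    naive = np.log(p)                 # ln(0) -> -inf: the loss blows up

# The same quantity, ln(sigma(x)) for target y = 1, via the stable identity
# derived below: -(max(x, 0) - x*y + ln(1 + e^{-|x|}))
stable = -(np.maximum(x, 0) - x * 1.0 + np.log1p(np.exp(-np.abs(x))))
print(naive, stable)
```

The naive path produces `-inf`, while the stable path returns the exact value `-1000.0`.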
To avoid this, Sorix uses a mathematically equivalent but numerically stable form for each element.
Mathematical Derivation¶
We know that $\ln(\sigma(x)) = \ln(\frac{1}{1+e^{-x}}) = -\ln(1+e^{-x})$. And $\ln(1-\sigma(x)) = \ln(\frac{e^{-x}}{1+e^{-x}}) = -x - \ln(1+e^{-x})$.
Substituting these into the BCE formula for a single element: $$l = - [y (-\ln(1+e^{-x})) + (1-y)(-x - \ln(1+e^{-x}))]$$ $$l = y \ln(1+e^{-x}) + (1-y)x + (1-y)\ln(1+e^{-x})$$ $$l = (1-y)x + \ln(1+e^{-x})$$
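The algebraic simplification can be spot-checked numerically (a quick NumPy sketch, using moderate logits where the naive form is still well-behaved):

```python
import numpy as np

# Spot-check: l = (1-y)x + ln(1+e^{-x}) matches the original per-element BCE
rng = np.random.default_rng(0)
x = rng.normal(size=100)
for y in (0.0, 1.0):
    s = 1.0 / (1.0 + np.exp(-x))
    full = -(y * np.log(s) + (1.0 - y) * np.log(1.0 - s))
    short = (1.0 - y) * x + np.log1p(np.exp(-x))
    assert np.allclose(full, short)
print("simplification holds for y = 0 and y = 1")
```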
To make it stable for both large positive and negative $x$, we use the identity $\ln(1+e^{-x}) = \max(-x, 0) + \ln(1+e^{-|x|})$. The final stable per-element loss implemented in Sorix is:
$$l = \max(x, 0) - x \cdot y + \ln(1 + e^{-|x|})$$
And the final loss is the mean of these values: $L = \text{mean}(l)$.
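A compact NumPy sketch of this stable formula (`bce_with_logits` is an illustrative name chosen here, not Sorix's internal function):

```python
import numpy as np

def bce_with_logits(logits, targets):
    # max(x, 0) - x*y + ln(1 + e^{-|x|}), averaged over the batch
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    per_element = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return per_element.mean()

# Finite even at extremes where the naive formula returns inf or nan
print(f"{bce_with_logits([100.0, -100.0, 0.0], [1.0, 0.0, 1.0]):.4f}")  # 0.2310
```

Note the use of `log1p`, which computes $\ln(1 + t)$ accurately for small $t$ where `log(1 + t)` would lose precision.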
Implementation Optimization¶
Sorix further optimizes this by reusing intermediate values ($e^{-|x|}$) to calculate both the loss and the probabilities needed for the gradient, avoiding redundant exponential calculations.
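One way to sketch that reuse in NumPy (this mirrors the description above, not Sorix's actual source):

```python
import numpy as np

def loss_and_probs(x, y):
    # exp(-|x|) is computed once and feeds both the loss and sigma(x)
    exp_neg_abs = np.exp(-np.abs(x))
    loss = np.mean(np.maximum(x, 0) - x * y + np.log1p(exp_neg_abs))
    denom = 1.0 + exp_neg_abs
    probs = np.where(x >= 0, 1.0 / denom, exp_neg_abs / denom)  # stable sigmoid
    return loss, probs

x = np.array([100.0, -100.0, 0.0])
y = np.array([1.0, 0.0, 1.0])
loss, probs = loss_and_probs(x, y)
print(f"loss={loss:.4f}, probs={probs}")
```

The branched `np.where` keeps the sigmoid itself stable: for negative `x` it evaluates $\frac{e^{-|x|}}{1+e^{-|x|}}$, avoiding the overflow in $e^{-x}$.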
# Uncomment the next line and run this cell to install sorix
#!pip install 'sorix @ git+https://github.com/Mitchell-Mirano/sorix.git@main'
import numpy as np
from sorix import tensor
from sorix.nn import BCEWithLogitsLoss
# Create logits (+ve for class 1, -ve for class 0)
logits = tensor([100.0, -100.0, 0.0], requires_grad=True)
targets = tensor([1.0, 0.0, 1.0])
criterion = BCEWithLogitsLoss()
loss = criterion(logits, targets)
print(f"Logits (Extremes): {logits.numpy()}")
print(f"Targets: {targets.numpy()}")
print(f"Stable BCE Loss: {loss.item():.4f} (No NaNs or Warnings!)")
Logits (Extremes): [ 100. -100.    0.]
Targets: [1. 0. 1.]
Stable BCE Loss: 0.2310 (No NaNs or Warnings!)
Verification with Autograd¶
The gradient of this combined function is also remarkably simple and stable: $$\frac{\partial L}{\partial \hat{y}_i} = \frac{1}{n}(\sigma(\hat{y}_i) - y_i)$$
This avoids the saturation seen when the Sigmoid and the loss are differentiated separately: the chain rule then multiplies by $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, which underflows for large $|x|$, whereas the fused gradient $\sigma(x) - y$ stays well-behaved.
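A NumPy sketch of the contrast, at an extreme logit where the separate chain-rule product degenerates to `nan` while the fused form stays exact:

```python
import numpy as np

x, y = np.float64(-800.0), 1.0

# Separate path: d(loss)/dp * d(sigma)/dx
with np.errstate(over="ignore", divide="ignore", invalid="ignore"):
    p = 1.0 / (1.0 + np.exp(-x))   # underflows to exactly 0.0
    dl_dp = -y / p                 # -inf
    dp_dx = p * (1.0 - p)          # 0.0
    separate = dl_dp * dp_dx       # -inf * 0.0 -> nan

# Fused path: sigma(x) - y, with a stable sigma (x < 0 branch)
exp_neg_abs = np.exp(-np.abs(x))
p_stable = exp_neg_abs / (1.0 + exp_neg_abs)
fused = p_stable - y               # -1.0: a clean, full-strength gradient
print(separate, fused)
```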
loss.backward()
print(f"Gradients w.r.t logits: {logits.grad}")
# dL/d_logit = 1/n * (sigma(logit) - target)
n = logits.data.size
x = logits.data
# Truly stable sigmoid reusing e^{-|x|}
abs_x = np.abs(x)
exp_neg_abs_x = np.exp(-abs_x)
denom = 1 + exp_neg_abs_x
probs = np.where(x >= 0, 1 / denom, exp_neg_abs_x / denom)
manual_grad = (probs - targets.data) / n
print(f"Manual Gradients: {manual_grad}")
Gradients w.r.t logits: tensor([ 0.00000000e+00,  1.26116862e-44, -1.66666667e-01], dtype=sorix.float64)
Manual Gradients: [ 0.0000000e+00  1.2611686e-44 -1.6666667e-01]
Training Example¶
Let's see how BCEWithLogitsLoss drives a logit toward the desired class.
from sorix.optim import SGD
logits = tensor([-5.0], requires_grad=True) # Starting at class 0
target = tensor([1.0]) # Target class 1
optimizer = SGD([logits], lr=0.5)
print(f"Initial Logits: {logits.item():.2f}")
for i in range(21):
    loss = criterion(logits, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        # Probability after Sigmoid
        prob = 1 / (1 + np.exp(-logits.item()))
        print(f"Step {i:2d} | Logit: {logits.item():6.4f} | Prob: {prob:6.4f} | Loss: {loss.item():6.4f}")
print(f"\nFinal Logit: {logits.item():.2f} (Close to +ve for class 1)")
Initial Logits: -5.00
Step  0 | Logit: -4.5033 | Prob: 0.0110 | Loss: 5.0067
Step  5 | Logit: -2.0912 | Prob: 0.1100 | Loss: 2.6300
Step 10 | Logit: -0.1809 | Prob: 0.4549 | Loss: 0.9685
Step 15 | Logit: 0.8868 | Prob: 0.7082 | Loss: 0.3955
Step 20 | Logit: 1.4911 | Prob: 0.8162 | Loss: 0.2221

Final Logit: 1.49 (Close to +ve for class 1)