# Source code for src.canns.trainer.oja
"""Oja's normalized Hebbian learning trainer."""
from __future__ import annotations
from collections.abc import Iterable
import brainpy.math as bm
import jax.numpy as jnp
from ..models.brain_inspired import BrainInspiredModel
from ._base import Trainer
from .utils import normalize_weight_rows
__all__ = ["OjaTrainer"]
class OjaTrainer(Trainer):
    """
    Oja's normalized Hebbian learning trainer.

    Oja's rule stabilizes pure Hebbian growth by introducing a weight-dependent
    normalization term, enabling single-neuron principal component extraction
    without unbounded weight magnitudes.

    Learning Rule:
        ΔW_ij = η * (y_i * x_j - y_i^2 * W_ij)

    where:
        - W_ij is the weight from input j to output i
        - x_j is the input activity
        - y_i is the output activity (y = W @ x)
        - η is the learning rate

    The rule can be rewritten as:
        ΔW = η * (y @ x^T - diag(y^2) @ W)

    This naturally leads to weight normalization and PCA extraction.

    Reference:
        Oja, E. (1982). Simplified neuron model as a principal component analyzer.
        Journal of Mathematical Biology, 15(3), 267-273.
    """

    def __init__(
        self,
        model: BrainInspiredModel,
        learning_rate: float = 0.01,
        normalize_weights: bool = True,
        weight_attr: str = "W",
        compiled: bool = True,
        **kwargs,
    ):
        """
        Initialize Oja trainer.

        Args:
            model: The model to train (typically LinearLayer)
            learning_rate: Learning rate η for weight updates
            normalize_weights: Whether to normalize weights to unit norm after each update
            weight_attr: Name of model attribute holding the connection weights
            compiled: Whether to use JIT-compiled training loop (default: True)
            **kwargs: Additional arguments passed to parent Trainer
        """
        super().__init__(model=model, **kwargs)
        self.learning_rate = learning_rate
        self.normalize_weights = normalize_weights
        self.weight_attr = weight_attr
        self.compiled = compiled

    def train(self, train_data: Iterable):
        """
        Train the model using Oja's rule.

        Dispatches to the scan-based loop when ``self.compiled`` is true and
        to the plain Python loop otherwise.

        Args:
            train_data: Iterable of input patterns (each of shape (input_size,))

        Raises:
            AttributeError: If the model has no ``self.weight_attr`` attribute,
                or that attribute does not expose ``.value``.
        """
        weight_param = getattr(self.model, self.weight_attr, None)
        if weight_param is None or not hasattr(weight_param, "value"):
            raise AttributeError(
                f"Model does not have a '{self.weight_attr}' parameter with .value attribute"
            )

        if self.compiled:
            self._train_compiled(train_data, weight_param)
        else:
            self._train_uncompiled(train_data, weight_param)

    def _oja_step(self, W, x):
        """
        Apply one Oja update for a single input pattern.

        Shared by the compiled and uncompiled loops so both paths always
        apply exactly the same rule.

        Args:
            W: Current weight matrix
            x: Single input pattern

        Returns:
            The updated weights, passed through ``normalize_weight_rows`` when
            ``self.normalize_weights`` is set.
        """
        # Output activity: y = W @ x
        y = W @ x
        # Hebbian term: y @ x^T
        hebbian = jnp.outer(y, x)
        # Decay term diag(y^2) @ W: scale row i of W by y_i^2. A column-shaped
        # y^2 broadcasts across each row, so no explicit ones matrix is needed.
        decay = jnp.reshape(y * y, (-1, 1)) * W
        W = W + self.learning_rate * (hebbian - decay)
        if self.normalize_weights:
            W = normalize_weight_rows(W)
        return W

    def _train_compiled(self, train_data: Iterable, weight_param):
        """
        Scan-based training loop using ``bm.scan``.

        Patterns are materialized and stacked into one float32 array, the
        weights are cast to float32, and the final carry of the scan is
        written back to ``weight_param.value``.

        Args:
            train_data: Iterable of input patterns
            weight_param: Weight parameter object
        """
        patterns = [jnp.asarray(p, dtype=jnp.float32) for p in train_data]
        if not patterns:
            # jnp.stack rejects an empty sequence; with no patterns there is
            # nothing to learn, so leave the weights untouched (this matches
            # what the uncompiled loop does for empty input).
            return
        stacked = jnp.stack(patterns)

        W_init = jnp.asarray(weight_param.value, dtype=jnp.float32)

        def train_step(W, x):
            # The scan body returns (carry, per-step output); nothing is
            # collected per step.
            return self._oja_step(W, x), None

        W_final, _ = bm.scan(train_step, W_init, stacked)
        weight_param.value = W_final

    def _train_uncompiled(self, train_data: Iterable, weight_param):
        """
        Python loop training (fallback, slower but more flexible).

        Args:
            train_data: Iterable of input patterns
            weight_param: Weight parameter object
        """
        W = weight_param.value
        for pattern in train_data:
            x = jnp.asarray(pattern, dtype=jnp.float32)
            W = self._oja_step(W, x)
        weight_param.value = W

    def predict(self, pattern, *args, **kwargs):
        """
        Predict output for a single input pattern.

        Uses ``self.model.forward`` when the model defines it; otherwise
        falls back to multiplying the stored weights by the pattern.
        Extra positional and keyword arguments are accepted but unused.

        Args:
            pattern: Input pattern of shape (input_size,)

        Returns:
            Output pattern of shape (output_size,)
        """
        if hasattr(self.model, "forward"):
            return self.model.forward(pattern)

        # Fallback: direct computation y = W @ x
        weight_param = getattr(self.model, self.weight_attr)
        x = jnp.asarray(pattern, dtype=jnp.float32)
        return weight_param.value @ x