src.canns.trainer
Training utilities for CANNs models.
The module exposes the abstract Trainer base class and concrete implementations
of classic brain-inspired learning algorithms: HebbianTrainer, AntiHebbianTrainer,
OjaTrainer, BCMTrainer, SangerTrainer, and STDPTrainer.
Classes
AntiHebbianTrainer | Anti-Hebbian trainer for pattern decorrelation and unlearning.
BCMTrainer | BCM (Bienenstock-Cooper-Munro) sliding-threshold plasticity trainer.
HebbianTrainer | Generic Hebbian trainer with progress reporting.
OjaTrainer | Oja's normalized Hebbian learning trainer.
STDPTrainer | STDP (Spike-Timing-Dependent Plasticity) trainer.
SangerTrainer | Sanger's rule (Generalized Hebbian Algorithm) for multiple PC extraction.
Trainer | Abstract base class for training utilities in CANNs.
Package Contents
- class src.canns.trainer.AntiHebbianTrainer(model, **kwargs)
Bases: HebbianTrainer
Anti-Hebbian trainer for pattern decorrelation and unlearning.
Overview
- Implements the anti-Hebbian learning rule: “Neurons that fire together, wire apart”.
- Uses negative weight updates, W <- W - Σ (t t^T), instead of positive ones.
- Inherits all functionality from HebbianTrainer (predict, predict_batch, etc.).
Applications
- Sparse coding and independent component analysis
- Competitive learning networks
- Decorrelation and whitening of feature representations
- Lateral inhibition modeling
- Selective forgetting / pattern unlearning
Learning Rule
- For patterns x, compute the optional mean activity rho and update: W <- W - sum_i (x_i - rho)(x_i - rho)^T (note the minus sign).
- If subtract_mean=True, patterns are centered by the mean: t = x - rho.
- If normalize_by_patterns=True, the update is divided by the number of patterns.
- All options from HebbianTrainer apply (subtract_mean, zero_diagonal, etc.).
- Example
>>> model = AmariHopfieldNetwork(num_neurons=100, activation="tanh")
>>> # Train with Hebbian first
>>> hebb_trainer = HebbianTrainer(model)
>>> hebb_trainer.train(all_patterns)
>>> # Then apply anti-Hebbian to unlearn specific pattern
>>> anti_trainer = AntiHebbianTrainer(model, subtract_mean=False)
>>> anti_trainer.train([pattern_to_forget])
Initialize Anti-Hebbian trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The model to train
**kwargs – Additional arguments passed to HebbianTrainer
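The sign flip in the anti-Hebbian rule can be illustrated with plain NumPy, independent of the library. This is a minimal sketch, not the trainer's actual code: the helper names `outer_product_sum` and `apply_update` are hypothetical, `rho` is assumed to be a single scalar mean over all patterns, and mean subtraction and normalization are switched off so that the unlearning identity holds exactly.

```python
import numpy as np

def outer_product_sum(patterns, subtract_mean=False):
    """Return sum_i (x_i - rho)(x_i - rho)^T for a stack of patterns."""
    X = np.asarray(patterns, dtype=float)        # shape (P, N)
    # Assumption: rho is one scalar mean over the whole data set.
    rho = X.mean() if subtract_mean else 0.0
    T = X - rho
    return T.T @ T                               # shape (N, N)

def apply_update(W, patterns, sign=1.0):
    """sign=+1.0 is the Hebbian update, sign=-1.0 the anti-Hebbian one."""
    W_new = W + sign * outer_product_sum(patterns)
    np.fill_diagonal(W_new, 0.0)                 # mimic zero_diagonal=True
    return W_new

rng = np.random.default_rng(0)
patterns = rng.choice([-1.0, 1.0], size=(3, 16))
W0 = np.zeros((16, 16))

W_all = apply_update(W0, patterns)                        # store three patterns
W_forgot = apply_update(W_all, patterns[:1], sign=-1.0)   # unlearn the first one
W_rest = apply_update(W0, patterns[1:])                   # store only the other two

print(np.allclose(W_forgot, W_rest))  # True
```

Under these assumptions, unlearning one pattern leaves exactly the weights that Hebbian training on the remaining patterns would have produced. With `subtract_mean=True` the identity is only approximate, because `rho` changes with the pattern set.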
- class src.canns.trainer.BCMTrainer(model, learning_rate=0.01, weight_attr='W', compiled=True, **kwargs)
Bases: src.canns.trainer._base.Trainer
BCM (Bienenstock-Cooper-Munro) sliding-threshold plasticity trainer.
The BCM rule uses a dynamic postsynaptic threshold to switch between potentiation and depression based on recent activity, yielding stable receptive-field development and experience-dependent refinement.
- Learning Rule:
ΔW_ij = η * y_i * (y_i - θ_i) * x_j
- where:
W_ij is the weight from input j to neuron i
x_j is the presynaptic activity
y_i is the postsynaptic activity
θ_i is the modification threshold for neuron i
- The threshold θ evolves as a sliding average:
θ_i = <y_i^2>
- This creates two regimes:
If y > θ: potentiation (LTP, strengthen synapses)
If y < θ: depression (LTD, weaken synapses)
- Reference:
Bienenstock, E. L., Cooper, L. N., & Munro, P. W. (1982). Theory for the development of neuron selectivity. Journal of Neuroscience, 2(1), 32-48.
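The two regimes of the BCM rule can be checked numerically with a small NumPy sketch. It is an illustration only, not the library's training loop: `bcm_step` and `theta_rate` are hypothetical names, the layer is assumed to be linear (`y = W @ x`), and the sliding average `<y_i^2>` is assumed to be an exponential moving average.

```python
import numpy as np

def bcm_step(W, theta, x, learning_rate=0.01, theta_rate=0.1):
    """One BCM update for a linear layer y = W @ x."""
    y = W @ x                                            # postsynaptic activity, shape (M,)
    dW = learning_rate * np.outer(y * (y - theta), x)    # eta * y_i * (y_i - theta_i) * x_j
    # Assumption: <y_i^2> is tracked as an exponential moving average.
    theta_new = (1.0 - theta_rate) * theta + theta_rate * y**2
    return W + dW, theta_new

W = np.array([[0.5, 0.5]])
x = np.array([1.0, 1.0])                                 # gives y = 1.0

W_ltp, _ = bcm_step(W, theta=np.array([0.5]), x=x)       # y > theta: potentiation
W_ltd, _ = bcm_step(W, theta=np.array([2.0]), x=x)       # y < theta: depression

print(W_ltp - W, W_ltd - W)
```

The same input strengthens the synapses when the threshold is below the output and weakens them when it is above, which is the switching behaviour described above.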
Initialize BCM trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The model to train (typically LinearLayer with use_bcm_threshold=True)
learning_rate (float) – Learning rate η for weight updates
weight_attr (str) – Name of model attribute holding the connection weights
compiled (bool) – Whether to use JIT-compiled training loop (default: True)
**kwargs – Additional arguments passed to parent Trainer
- predict(pattern, *args, **kwargs)
Predict output for a single input pattern.
- Parameters:
pattern – Input pattern of shape (input_size,)
- Returns:
Output pattern of shape (output_size,)
- train(train_data)
Train the model using the BCM rule.
- Parameters:
train_data (collections.abc.Iterable) – Iterable of input patterns (each of shape (input_size,))
- compiled = True
- learning_rate = 0.01
- weight_attr = 'W'
- class src.canns.trainer.HebbianTrainer(model, show_iteration_progress=False, compiled_prediction=True, *, weight_attr='W', subtract_mean=True, zero_diagonal=True, normalize_by_patterns=True, prefer_generic=True, state_attr=None, prefer_generic_predict=True, preserve_on_resize=True)
Bases: src.canns.trainer._base.Trainer
Generic Hebbian trainer with progress reporting.
Overview
- Uses a model-exposed weight parameter (default attribute name: W) to apply a standard Hebbian update. If unavailable, falls back to the model’s apply_hebbian_learning.
- Works with models that expose a parameter object with a .value ndarray of shape (N, N) (e.g., bm.Variable).
Generic rule
- For patterns x (shape: (N,)), compute the optional mean activity rho and update W <- W + sum_i (x_i - rho)(x_i - rho)^T.
- Options allow zeroing the diagonal and normalizing by the number of patterns.
Key options
- weight_attr: Name of the weight attribute on the model (default: “W”).
- subtract_mean: Whether to center patterns by mean activity rho.
- zero_diagonal: Whether to set the diagonal of W to zero after the update.
- normalize_by_patterns: Divide accumulated outer products by the number of patterns.
- prefer_generic: Prefer the generic Hebbian rule over the model-specific method.
- state_attr: Name of the state vector attribute for prediction (default: s; or the model-provided predict_state_attr).
- prefer_generic_predict: Prefer the trainer’s generic predict loop over the model’s predict implementation (falls back automatically when unsupported).
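The generic rule and its three options can be sketched in a few lines of NumPy. This is an assumption-laden illustration rather than the trainer's implementation: `hebbian_weights` is a hypothetical helper, `rho` is taken to be a scalar mean over all patterns, and the patterns are chosen to be mutually orthogonal so that the recall check is exact.

```python
import numpy as np

def hebbian_weights(patterns, subtract_mean=True, zero_diagonal=True,
                    normalize_by_patterns=True):
    """Build W from scratch as sum_i (x_i - rho)(x_i - rho)^T with the listed options."""
    X = np.asarray(patterns, dtype=float)         # shape (P, N)
    rho = X.mean() if subtract_mean else 0.0      # assumption: scalar mean activity
    T = X - rho
    W = T.T @ T
    if normalize_by_patterns:
        W = W / X.shape[0]
    if zero_diagonal:
        np.fill_diagonal(W, 0.0)
    return W

# Three orthogonal, zero-mean +/-1 patterns: rows 1..3 of an 8x8 Sylvester-Hadamard matrix.
H2 = np.array([[1.0, 1.0], [1.0, -1.0]])
H8 = np.kron(np.kron(H2, H2), H2)
patterns = H8[1:4]

W = hebbian_weights(patterns)

# One synchronous sign update per stored pattern.
recalled = np.where(W @ patterns.T >= 0, 1.0, -1.0).T
print(np.array_equal(recalled, patterns))  # True
```

For this pattern set each stored pattern is a fixed point of the sign update, which is the property the Hebbian outer-product rule is designed to produce. Random, non-orthogonal patterns would generally satisfy it only approximately.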
Initialize Hebbian trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The model to train
show_iteration_progress (bool) – Whether to show progress for individual pattern convergence
compiled_prediction (bool) – Whether to use compiled prediction by default (faster but no iteration progress)
weight_attr (str | None) – Name of model attribute holding the connection weights (default: “W”).
subtract_mean (bool) – Subtract dataset mean activity (rho) before outer-product.
zero_diagonal (bool) – Force zero self-connections after update.
normalize_by_patterns (bool) – Divide accumulated outer-products by number of patterns.
prefer_generic (bool) – If True, use trainer’s generic Hebbian rule when possible; otherwise call the model’s own implementation if available.
- predict(pattern, num_iter=20, compiled=None, show_progress=None, convergence_threshold=1e-10, progress_callback=None)
Predict a single pattern.
- Parameters:
- Returns:
Predicted pattern
- predict_batch(patterns, num_iter=20, compiled=None, show_sample_progress=True, show_iteration_progress=None, convergence_threshold=1e-10)
Predict multiple patterns with progress reporting.
- Parameters:
patterns (list) – List of input patterns to predict
num_iter (int) – Maximum number of iterations per pattern
compiled (bool | None) – Override default compiled setting
show_sample_progress (bool) – Whether to show progress across samples
show_iteration_progress (bool | None) – Override default iteration progress setting
convergence_threshold (float) – Energy change threshold for convergence
- Returns:
List of predicted patterns
- Return type:
list
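How `num_iter` and `convergence_threshold` might interact during recall can be sketched as follows. The source only says that the threshold applies to the energy change, so everything else here is an assumption: the Hopfield energy `E = -0.5 * s^T W s`, the synchronous sign update, and the helper names `energy` and `recall` are illustrative and not taken from the library.

```python
import numpy as np

def energy(W, s):
    """Assumed Hopfield energy E = -0.5 * s^T W s."""
    return -0.5 * s @ W @ s

def recall(W, pattern, num_iter=20, convergence_threshold=1e-10):
    """Iterate a sign update until the energy change falls below the threshold."""
    s = np.asarray(pattern, dtype=float)
    e_prev = energy(W, s)
    for it in range(1, num_iter + 1):
        s = np.where(W @ s >= 0, 1.0, -1.0)       # synchronous update
        e = energy(W, s)
        if abs(e - e_prev) < convergence_threshold:
            return s, it                          # converged early
        e_prev = e
    return s, num_iter                            # hit the iteration cap

# Store a single pattern with the outer-product rule and a zero diagonal.
p = np.array([1, -1, 1, 1, -1, -1, 1, -1], dtype=float)
W = np.outer(p, p)
np.fill_diagonal(W, 0.0)

cue = p.copy()
cue[0] *= -1                                      # corrupt one element
restored, n_steps = recall(W, cue)
print(np.array_equal(restored, p), n_steps)  # True 2
```

The first iteration repairs the corrupted element and lowers the energy; the second leaves the state unchanged, so the energy change is zero and the loop stops well before `num_iter`.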
- train(train_data)
Train the model using Hebbian learning.
Behavior
- Preferred path: apply a generic Hebbian update directly to model.<weight_attr>.
- Fallback path: call model.apply_hebbian_learning(train_data) if the generic path is unavailable.
Requirements for generic path
- Model must expose model.<weight_attr> with a .value array of shape (N, N).
- Optionally, models can declare a weight_attr property to specify the attribute name, allowing HebbianTrainer(..., weight_attr=None).
- normalize_by_patterns = True
- prefer_generic = True
- prefer_generic_predict = True
- preserve_on_resize = True
- state_attr = None
- subtract_mean = True
- weight_attr = 'W'
- zero_diagonal = True
- class src.canns.trainer.OjaTrainer(model, learning_rate=0.01, normalize_weights=True, weight_attr='W', compiled=True, **kwargs)
Bases: src.canns.trainer._base.Trainer
Oja’s normalized Hebbian learning trainer.
Oja’s rule stabilizes pure Hebbian growth by introducing a weight-dependent normalization term, enabling single-neuron principal component extraction without unbounded weight magnitudes.
- Learning Rule:
ΔW_ij = η * (y_i * x_j - y_i^2 * W_ij)
- where:
W_ij is the weight from input j to output i
x_j is the input activity
y_i is the output activity (y = W @ x)
η is the learning rate
- The rule can be rewritten as:
ΔW = η * (y @ x^T - diag(y^2) @ W)
For zero-mean inputs, this drives each weight vector toward unit norm and toward the leading principal component of the input.
- Reference:
Oja, E. (1982). Simplified neuron model as a principal component analyzer. Journal of Mathematical Biology, 15(3), 267-273.
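The matrix form of Oja's rule translates almost directly into NumPy. The sketch below is illustrative and does not reproduce the library's compiled loop: `oja_step` is a hypothetical helper, the explicit `normalize_weights` step is left out, and the synthetic data (zero mean, largest variance along the first axis) is an assumption chosen so that the expected direction is obvious.

```python
import numpy as np

def oja_step(W, x, learning_rate=0.01):
    """One update: dW = eta * (y x^T - diag(y^2) W), with y = W @ x."""
    y = W @ x
    # (y**2)[:, None] * W scales row i of W by y_i^2, i.e. diag(y^2) @ W.
    return W + learning_rate * (np.outer(y, x) - (y**2)[:, None] * W)

rng = np.random.default_rng(0)
# Zero-mean 2-D inputs whose largest variance lies along the first axis.
data = rng.normal(size=(5000, 2)) * np.array([2.0, 0.5])

W = rng.normal(scale=0.1, size=(1, 2))            # one output neuron
for x in data:
    W = oja_step(W, x)

print(W, np.linalg.norm(W))
```

After training, the single weight vector should point (up to sign) along the first axis with a norm close to one, without any explicit renormalization.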
Initialize Oja trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The model to train (typically LinearLayer)
learning_rate (float) – Learning rate η for weight updates
normalize_weights (bool) – Whether to normalize weights to unit norm after each update
weight_attr (str) – Name of model attribute holding the connection weights
compiled (bool) – Whether to use JIT-compiled training loop (default: True)
**kwargs – Additional arguments passed to parent Trainer
- predict(pattern, *args, **kwargs)
Predict output for a single input pattern.
- Parameters:
pattern – Input pattern of shape (input_size,)
- Returns:
Output pattern of shape (output_size,)
- train(train_data)
Train the model using Oja’s rule.
- Parameters:
train_data (collections.abc.Iterable) – Iterable of input patterns (each of shape (input_size,))
- compiled = True
- learning_rate = 0.01
- normalize_weights = True
- weight_attr = 'W'
- class src.canns.trainer.STDPTrainer(model, learning_rate=0.01, A_plus=0.005, A_minus=0.00525, weight_attr='W', w_min=0.0, w_max=1.0, compiled=True, **kwargs)
Bases: src.canns.trainer._base.Trainer
STDP (Spike-Timing-Dependent Plasticity) trainer.
STDP is a biologically-inspired learning rule that adjusts synaptic weights based on the precise timing of pre- and post-synaptic spikes. Synapses are strengthened when pre-synaptic spikes precede post-synaptic spikes (LTP), and weakened when the order is reversed (LTD).
- Trace-based Learning Rule:
ΔW_ij = A_plus * trace_pre[j] * spike_post[i] - A_minus * trace_post[i] * spike_pre[j]
- where:
W_ij is the weight from input j to neuron i
spike_pre[j] is the presynaptic spike (0 or 1)
spike_post[i] is the postsynaptic spike (0 or 1)
trace_pre[j] is the exponential trace of presynaptic spikes
trace_post[i] is the exponential trace of postsynaptic spikes
A_plus controls LTP (long-term potentiation) magnitude
A_minus controls LTD (long-term depression) magnitude
- The spike traces evolve as:
trace = decay * trace + spike
This provides a temporal window for spike-timing correlations.
References
Gerstner & Kistler (2002): Spiking Neuron Models
Morrison et al. (2008): Phenomenological models of synaptic plasticity
Bi & Poo (1998): Synaptic modifications in cultured hippocampal neurons
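The trace-based rule can be exercised on a single synapse to confirm the timing dependence. This is a sketch under stated assumptions, not the library's implementation: `stdp_step` and `run` are hypothetical helpers, the `decay` value of 0.9 is invented for illustration, traces are assumed to be updated before the weight change is computed, and the global `learning_rate` multiplier is omitted.

```python
import numpy as np

def stdp_step(W, trace_pre, trace_post, spike_pre, spike_post,
              A_plus=0.005, A_minus=0.00525, decay=0.9, w_min=0.0, w_max=1.0):
    """One trace-based STDP update; W[i, j] connects input j to neuron i."""
    # Assumption: traces are updated first, then used in the weight update.
    trace_pre = decay * trace_pre + spike_pre
    trace_post = decay * trace_post + spike_post
    dW = (A_plus * np.outer(spike_post, trace_pre)       # pre before post: LTP
          - A_minus * np.outer(trace_post, spike_pre))   # post before pre: LTD
    return np.clip(W + dW, w_min, w_max), trace_pre, trace_post

def run(spike_pairs):
    """Apply a sequence of (pre, post) spikes to one synapse starting at 0.5."""
    W = np.array([[0.5]])
    tr_pre, tr_post = np.zeros(1), np.zeros(1)
    for pre, post in spike_pairs:
        W, tr_pre, tr_post = stdp_step(W, tr_pre, tr_post,
                                       np.array([pre]), np.array([post]))
    return W[0, 0]

w_pre_first = run([(1.0, 0.0), (0.0, 1.0)])    # pre spike, then post spike
w_post_first = run([(0.0, 1.0), (1.0, 0.0)])   # post spike, then pre spike
print(w_pre_first > 0.5, w_post_first < 0.5)  # True True
```

When the presynaptic spike comes first, the decayed presynaptic trace meets the postsynaptic spike and the weight grows; in the reverse order the postsynaptic trace meets the presynaptic spike and the weight shrinks. The clipping step keeps weights inside `[w_min, w_max]`.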
Initialize STDP trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The spiking model to train (typically SpikingLayer)
learning_rate (float) – Global learning rate multiplier (default: 0.01)
A_plus (float) – LTP magnitude (default: 0.005)
A_minus (float) – LTD magnitude (default: 0.00525, slightly larger than A_plus for stability)
weight_attr (str) – Name of model attribute holding the connection weights
w_min (float) – Minimum weight value (default: 0.0 for excitatory synapses)
w_max (float) – Maximum weight value (default: 1.0)
compiled (bool) – Whether to use JIT-compiled training loop (default: True)
**kwargs – Additional arguments passed to parent Trainer
- predict(pattern, *args, **kwargs)
Predict output spikes for a single input spike pattern.
- Parameters:
pattern – Input spike pattern of shape (input_size,)
- Returns:
Output spike pattern of shape (output_size,) with binary values (0 or 1)
- train(train_data)
Train the model using the STDP rule.
- Parameters:
train_data (collections.abc.Iterable) – Iterable of input spike patterns (each of shape (input_size,)). Each pattern should contain binary values (0 or 1).
- A_minus = 0.00525
- A_plus = 0.005
- compiled = True
- learning_rate = 0.01
- w_max = 1.0
- w_min = 0.0
- weight_attr = 'W'
- class src.canns.trainer.SangerTrainer(model, learning_rate=0.01, normalize_weights=True, weight_attr='W', compiled=True, **kwargs)
Bases: src.canns.trainer._base.Trainer
Sanger’s rule (Generalized Hebbian Algorithm) for multiple PC extraction.
Extends Oja’s rule with Gram-Schmidt orthogonalization to extract multiple principal components. Each neuron learns to be orthogonal to all previous ones.
- Learning Rule:
ΔW_i = η * (y_i * x - y_i * Σ_{j≤i} y_j * W_j)
- where:
W_i is the i-th neuron’s weight vector
y = W @ x is the output vector
The sum enforces orthogonality (Gram-Schmidt process)
This allows sequential extraction of orthogonal principal components, with neuron i converging to the i-th principal component.
- Reference:
Sanger, T. D. (1989). Optimal unsupervised learning in a single-layer linear feedforward neural network. Neural Networks, 2(6), 459-473.
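Sanger's rule has a compact matrix form in which the sum over `j ≤ i` becomes a lower-triangular mask. The sketch below illustrates that form with NumPy; it is not the library's code. `sanger_step` is a hypothetical helper, the explicit `normalize_weights` step is omitted, and the synthetic data (zero mean, axis-aligned variances) and the smaller learning rate are assumptions chosen for a stable demonstration.

```python
import numpy as np

def sanger_step(W, x, learning_rate=0.005):
    """One update: dW_i = eta * (y_i x - y_i * sum_{j<=i} y_j W_j), with y = W @ x."""
    y = W @ x
    # Row i of tril(y y^T) @ W equals y_i * sum_{j<=i} y_j * W_j.
    return W + learning_rate * (np.outer(y, x) - np.tril(np.outer(y, y)) @ W)

rng = np.random.default_rng(0)
# Zero-mean 3-D inputs with decreasing variance along the coordinate axes.
data = rng.normal(size=(20000, 3)) * np.array([2.0, 1.0, 0.5])

W = rng.normal(scale=0.1, size=(2, 3))            # two output neurons
for x in data:
    W = sanger_step(W, x)

print(np.round(W, 2))
print(np.round(W @ W.T, 2))
```

With this data the first row should align (up to sign) with the first axis and the second row with the second axis, and `W @ W.T` should be close to the identity, reflecting the orthogonality that the triangular sum enforces.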
Initialize Sanger trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The model to train (typically LinearLayer)
learning_rate (float) – Learning rate η for weight updates
normalize_weights (bool) – Whether to normalize weights to unit norm after each update
weight_attr (str) – Name of model attribute holding the connection weights
compiled (bool) – Whether to use JIT-compiled training loop (default: True)
**kwargs – Additional arguments passed to parent Trainer
- predict(pattern, *args, **kwargs)
Predict output for a single input pattern.
- Parameters:
pattern – Input pattern of shape (input_size,)
- Returns:
Output pattern of shape (output_size,)
- train(train_data)
Train the model using Sanger’s rule.
- Parameters:
train_data (collections.abc.Iterable) – Iterable of input patterns (each of shape (input_size,))
- compiled = True
- learning_rate = 0.01
- normalize_weights = True
- weight_attr = 'W'
- class src.canns.trainer.Trainer(model=None, *, show_iteration_progress=False, compiled_prediction=True)
Bases: abc.ABC
Abstract base class for training utilities in CANNs.
- configure_progress(*, show_iteration_progress=None, compiled_prediction=None)
Update progress-related flags for derived trainers.
- abstractmethod predict(pattern, *args, **kwargs)
Predict an output for a single pattern.
- predict_batch(patterns, *args, **kwargs)
Predict outputs for multiple patterns using predict.
- abstractmethod train(train_data)
Train the associated model with the provided dataset.
- compiled_prediction = True
- model = None
- show_iteration_progress = False
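The contract described by this base class (two abstract methods plus a `predict_batch` built on `predict`) can be illustrated with a self-contained stand-in. `TrainerSketch` and `MeanOffsetTrainer` are hypothetical classes written for this example; they mirror the documented signatures but are not the library's implementation, and the behaviour of `configure_progress` for `None` arguments is an assumption.

```python
from abc import ABC, abstractmethod

class TrainerSketch(ABC):
    """Stand-in mirroring the documented Trainer interface (not the library class)."""

    def __init__(self, model=None, *, show_iteration_progress=False,
                 compiled_prediction=True):
        self.model = model
        self.show_iteration_progress = show_iteration_progress
        self.compiled_prediction = compiled_prediction

    def configure_progress(self, *, show_iteration_progress=None,
                           compiled_prediction=None):
        # Assumption: None means "leave the current value unchanged".
        if show_iteration_progress is not None:
            self.show_iteration_progress = show_iteration_progress
        if compiled_prediction is not None:
            self.compiled_prediction = compiled_prediction

    @abstractmethod
    def train(self, train_data):
        """Train the associated model."""

    @abstractmethod
    def predict(self, pattern, *args, **kwargs):
        """Predict an output for a single pattern."""

    def predict_batch(self, patterns, *args, **kwargs):
        # Default batch prediction simply maps predict over the inputs.
        return [self.predict(p, *args, **kwargs) for p in patterns]

class MeanOffsetTrainer(TrainerSketch):
    """Toy subclass: training stores the data mean, prediction subtracts it."""

    def train(self, train_data):
        data = list(train_data)
        self.offset = sum(data) / len(data)

    def predict(self, pattern, *args, **kwargs):
        return pattern - self.offset

trainer = MeanOffsetTrainer()
trainer.train([1.0, 2.0, 3.0])
print(trainer.predict_batch([2.0, 5.0]))  # [0.0, 3.0]
```

A concrete trainer only has to supply `train` and `predict`; batch prediction and the progress flags come from the base class. Instantiating the abstract class directly raises `TypeError`.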