src.canns.trainer.hebbian
Classes
- AntiHebbianTrainer: Anti-Hebbian trainer for pattern decorrelation and unlearning.
- HebbianTrainer: Generic Hebbian trainer with progress reporting.
Module Contents
- class src.canns.trainer.hebbian.AntiHebbianTrainer(model, **kwargs)
Bases: HebbianTrainer
Anti-Hebbian trainer for pattern decorrelation and unlearning.
Overview
- Implements the anti-Hebbian learning rule: "Neurons that fire together, wire apart".
- Uses negative weight updates, W <- W - sum_i t_i t_i^T (where t_i is the i-th pattern, optionally mean-centered), instead of positive ones.
- Inherits all functionality from HebbianTrainer (predict, predict_batch, etc.).
Applications
- Sparse coding and independent component analysis
- Competitive learning networks
- Decorrelation and whitening of feature representations
- Lateral inhibition modeling
- Selective forgetting / pattern unlearning
Learning Rule
- For patterns x_i, compute the optional mean activity rho and update W <- W - sum_i (x_i - rho)(x_i - rho)^T (note the minus sign).
- If subtract_mean=True, patterns are centered by the mean: t = x - rho.
- If normalize_by_patterns=True, the accumulated update is divided by the number of patterns.
- All options from HebbianTrainer apply (subtract_mean, zero_diagonal, etc.).
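The learning rule above can be sketched as a standalone NumPy function. This is not the canns implementation: the name anti_hebbian_update is hypothetical, and treating rho as a single scalar (the mean over all entries of all patterns) is an assumption, since the text does not say whether rho is global or per-neuron.

```python
import numpy as np


def anti_hebbian_update(W, patterns, subtract_mean=True, zero_diagonal=True,
                        normalize_by_patterns=True):
    """Hypothetical sketch: return W - sum_i t_i t_i^T with t_i = x_i - rho."""
    X = np.asarray(patterns, dtype=float)      # shape (P, N), one pattern per row
    rho = X.mean() if subtract_mean else 0.0   # ASSUMPTION: rho is a scalar mean
    T = X - rho                                # centered patterns t_i
    delta = T.T @ T                            # sum_i t_i t_i^T, shape (N, N)
    if normalize_by_patterns:
        delta /= X.shape[0]                    # divide by the number of patterns P
    W_new = W - delta                          # minus sign: "wire apart"
    if zero_diagonal:
        np.fill_diagonal(W_new, 0.0)           # remove self-connections
    return W_new


# Usage: one anti-Hebbian step on an empty 3x3 weight matrix.
W = np.zeros((3, 3))
x = np.array([1.0, -1.0, 1.0])
W = anti_hebbian_update(W, [x], subtract_mean=False)
print(W)
```

Neurons that were co-active in x (same sign) end up with negative weights, and anti-correlated pairs end up with positive ones, which is the decorrelating effect the Overview describes.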
- Example
>>> model = AmariHopfieldNetwork(num_neurons=100, activation="tanh")
>>> # Train with Hebbian first
>>> hebb_trainer = HebbianTrainer(model)
>>> hebb_trainer.train(all_patterns)
>>> # Then apply anti-Hebbian to unlearn a specific pattern
>>> anti_trainer = AntiHebbianTrainer(model, subtract_mean=False)
>>> anti_trainer.train([pattern_to_forget])
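The Example relies on the anti-Hebbian term for pattern_to_forget cancelling that pattern's Hebbian term. Under the documented rules, the defaults scale the two terms differently: the Hebbian pass divides by the number of stored patterns and centers by rho, while the anti-Hebbian pass over a single pattern divides by 1. The cancellation is therefore only partial. The plain NumPy sketch below does not call canns. It turns off centering and normalization on both passes so that the arithmetic is exact. That choice is an assumption made for clarity, and the pattern values are made up.

```python
import numpy as np

# Made-up +/-1 patterns: a is kept, b plays the role of pattern_to_forget.
a = np.array([1, 1, -1, -1, 1, -1, 1, -1], dtype=float)
b = np.array([1, -1, 1, -1, -1, 1, 1, -1], dtype=float)
N = a.size


def outer_sum(patterns):
    """sum_i x_i x_i^T with no centering and no normalization (assumed options)."""
    X = np.asarray(patterns, dtype=float)
    return X.T @ X


W = np.zeros((N, N))
W += outer_sum([a, b])     # Hebbian pass: store both patterns
W -= outer_sum([b])        # anti-Hebbian pass: unlearn b
np.fill_diagonal(W, 0.0)   # zero_diagonal=True

expected = np.outer(a, a)
np.fill_diagonal(expected, 0.0)
print(np.allclose(W, expected))  # only a's contribution remains
```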
Initialize Anti-Hebbian trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The model to train
**kwargs – Additional arguments passed to HebbianTrainer
- class src.canns.trainer.hebbian.HebbianTrainer(model, show_iteration_progress=False, compiled_prediction=True, *, weight_attr='W', subtract_mean=True, zero_diagonal=True, normalize_by_patterns=True, prefer_generic=True, state_attr=None, prefer_generic_predict=True, preserve_on_resize=True)
Bases: src.canns.trainer._base.Trainer
Generic Hebbian trainer with progress reporting.
Overview
- Uses a model-exposed weight parameter (default attribute name: W) to apply a standard Hebbian update. If that parameter is unavailable, falls back to the model's apply_hebbian_learning.
- Works with models that expose a parameter object with a .value ndarray of shape (N, N) (e.g., bm.Variable).
Generic rule
- For patterns x_i (each of shape (N,)), compute the optional mean activity rho and update W <- W + sum_i (x_i - rho)(x_i - rho)^T.
- Options allow zeroing the diagonal and normalizing by the number of patterns.
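A NumPy sketch of the generic rule with the three options it mentions. The function hebbian_weights and its signature are hypothetical. Patterns are stacked as rows of a (P, N) array, and rho is assumed to be a scalar mean over all entries.

```python
import numpy as np


def hebbian_weights(patterns, subtract_mean=True, zero_diagonal=True,
                    normalize_by_patterns=True, W=None):
    """Hypothetical sketch: W + sum_i (x_i - rho)(x_i - rho)^T."""
    X = np.asarray(patterns, dtype=float)          # shape (P, N)
    P, N = X.shape
    W = np.zeros((N, N)) if W is None else np.array(W, dtype=float)
    rho = X.mean() if subtract_mean else 0.0       # ASSUMPTION: scalar rho
    T = X - rho
    delta = T.T @ T                                # sum of outer products
    if normalize_by_patterns:
        delta /= P                                 # average over patterns
    W += delta                                     # plus sign: Hebbian
    if zero_diagonal:
        np.fill_diagonal(W, 0.0)                   # no self-connections
    return W


W = hebbian_weights([[1, -1, 1, -1],
                     [1, 1, -1, -1]])
print(W)
```

For these two ±1 patterns rho is 0, so each off-diagonal weight is simply the average of x_i[j] * x_i[k] over the two patterns.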
Key options
- weight_attr: Name of the weight attribute on the model (default: "W").
- subtract_mean: Whether to center patterns by the mean activity rho.
- zero_diagonal: Whether to set the diagonal of W to zero after the update.
- normalize_by_patterns: Divide accumulated outer products by the number of patterns.
- prefer_generic: Prefer the generic Hebbian rule over the model-specific method.
- state_attr: Name of the state vector attribute used for prediction (default: s, or the model-provided predict_state_attr).
- prefer_generic_predict: Prefer the trainer's generic predict loop over the model's predict implementation (falls back automatically when unsupported).
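The subtract_mean option matters most for biased 0/1 patterns. In the sketch below, which uses plain NumPy, made-up patterns, and an assumed scalar rho, the uncentered weights are all non-negative. Centering introduces negative weights between neurons that are active in different patterns. The helper weights is hypothetical.

```python
import numpy as np

# Two sparse 0/1 patterns (made-up data): neurons 0-1 vs. neurons 2-3.
X = np.array([[1, 1, 0, 0, 0, 0],
              [0, 0, 1, 1, 0, 0]], dtype=float)


def weights(X, subtract_mean):
    """Hebbian weights with normalize_by_patterns=True and zero_diagonal=True."""
    rho = X.mean() if subtract_mean else 0.0   # ASSUMPTION: scalar rho
    T = X - rho
    W = T.T @ T / X.shape[0]
    np.fill_diagonal(W, 0.0)
    return W


W_raw = weights(X, subtract_mean=False)
W_centered = weights(X, subtract_mean=True)
print(W_raw.min(), W_centered.min())
```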
Initialize Hebbian trainer.
- Parameters:
model (src.canns.models.brain_inspired.BrainInspiredModel) – The model to train
show_iteration_progress (bool) – Whether to show progress for individual pattern convergence
compiled_prediction (bool) – Whether to use compiled prediction by default (faster but no iteration progress)
weight_attr (str | None) – Name of model attribute holding the connection weights (default: “W”).
subtract_mean (bool) – Subtract dataset mean activity (rho) before outer-product.
zero_diagonal (bool) – Force zero self-connections after update.
normalize_by_patterns (bool) – Divide accumulated outer-products by number of patterns.
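prefer_generic (bool) – If True, use the trainer's generic Hebbian rule when possible; otherwise call the model's own implementation if available.
state_attr (str | None) – Name of the state vector attribute used for prediction (default: s, or the model-provided predict_state_attr).
prefer_generic_predict (bool) – Prefer the trainer's generic predict loop over the model's predict implementation (falls back automatically when unsupported).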
- predict(pattern, num_iter=20, compiled=None, show_progress=None, convergence_threshold=1e-10, progress_callback=None)
Predict a single pattern.
- Parameters:
- Returns:
Predicted pattern
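The predict documentation does not spell out the update loop. As an illustration only, a classic Hopfield-style recall loop that honors num_iter and convergence_threshold could look like the sketch below. The sign activation, the synchronous updates, and the energy E(s) = -0.5 s^T W s are assumptions. The model's own predict (for example with activation="tanh") may behave differently. The names recall and energy are hypothetical.

```python
import numpy as np


def energy(W, s):
    """ASSUMPTION: standard Hopfield energy E(s) = -0.5 * s^T W s."""
    return -0.5 * s @ W @ s


def recall(W, pattern, num_iter=20, convergence_threshold=1e-10):
    """Synchronous sign updates until the energy change drops below the threshold."""
    s = np.asarray(pattern, dtype=float).copy()
    e_prev = energy(W, s)
    for _ in range(num_iter):
        s = np.where(W @ s >= 0, 1.0, -1.0)   # ASSUMPTION: sign activation
        e = energy(W, s)
        if abs(e - e_prev) < convergence_threshold:
            break                              # converged
        e_prev = e
    return s


# Store one pattern, flip one bit, and recall it.
x = np.array([1, -1, 1, -1, 1, -1], dtype=float)
W = np.outer(x, x)
np.fill_diagonal(W, 0.0)
probe = x.copy()
probe[0] = -probe[0]
restored = recall(W, probe)
print(restored)
```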
- predict_batch(patterns, num_iter=20, compiled=None, show_sample_progress=True, show_iteration_progress=None, convergence_threshold=1e-10)
Predict multiple patterns with progress reporting.
- Parameters:
patterns (list) – List of input patterns to predict
num_iter (int) – Maximum number of iterations per pattern
compiled (bool | None) – Override default compiled setting
show_sample_progress (bool) – Whether to show progress across samples
show_iteration_progress (bool | None) – Override default iteration progress setting
convergence_threshold (float) – Energy change threshold for convergence
- Returns:
List of predicted patterns
- Return type:
- train(train_data)
Train the model using Hebbian learning.
Behavior
- Preferred path: apply a generic Hebbian update directly to model.<weight_attr>.
- Fallback path: call model.apply_hebbian_learning(train_data) if the generic path is unavailable.
Requirements for generic path
- The model must expose model.<weight_attr> with a .value array of shape (N, N).
- Optionally, models can declare a weight_attr property to specify the attribute name, allowing HebbianTrainer(..., weight_attr=None).
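The two train paths and the weight_attr=None lookup can be mimicked with duck typing. In this sketch, Param (a stand-in for a parameter object such as bm.Variable), ToyModel, resolve_weight_param, and train are all hypothetical. The checks and the simplified update (scalar rho, default options) are assumptions about how such a dispatch could work. They are not the canns code.

```python
import numpy as np


class Param:
    """Hypothetical stand-in for a parameter object with a .value array."""
    def __init__(self, value):
        self.value = value


class ToyModel:
    """Hypothetical model that stores its weights under a custom name."""
    weight_attr = "J"  # model-declared weight attribute name

    def __init__(self, n):
        self.J = Param(np.zeros((n, n)))
        self.used_fallback = False

    def apply_hebbian_learning(self, train_data):
        self.used_fallback = True  # records that the fallback path ran


def resolve_weight_param(model, weight_attr):
    """Return the weight parameter, or None if the generic path is unavailable."""
    name = weight_attr if weight_attr is not None else getattr(model, "weight_attr", None)
    param = getattr(model, name, None) if name else None
    shape = np.shape(getattr(param, "value", None))
    if len(shape) != 2 or shape[0] != shape[1]:
        return None  # requires a square (N, N) .value array
    return param


def train(model, train_data, weight_attr="W", prefer_generic=True):
    param = resolve_weight_param(model, weight_attr) if prefer_generic else None
    if param is None:
        model.apply_hebbian_learning(train_data)  # fallback path
        return
    X = np.asarray(train_data, dtype=float)
    T = X - X.mean()                               # ASSUMPTION: scalar rho
    W = param.value + T.T @ T / X.shape[0]         # normalize_by_patterns=True
    np.fill_diagonal(W, 0.0)                       # zero_diagonal=True
    param.value = W                                # write back through .value


data = [[1, -1, 1, -1], [1, 1, -1, -1]]
model = ToyModel(4)
train(model, data, weight_attr=None)   # resolves "J" via model.weight_attr
other = ToyModel(4)
train(other, data, weight_attr="W")    # no attribute "W", so the fallback runs
```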