from __future__ import annotations
from collections.abc import Callable, Iterable
import brainpy.math as bm
import jax
import jax.numpy as jnp
from tqdm import tqdm # type: ignore
from ..models.brain_inspired import BrainInspiredModel
from ._base import Trainer
__all__ = ["HebbianTrainer", "AntiHebbianTrainer"]
[docs]
class HebbianTrainer(Trainer):
"""
Generic Hebbian trainer with progress reporting.
Overview
- Uses a model-exposed weight parameter (default attribute name: ``W``) to apply a
standard Hebbian update. If unavailable, falls back to the model's
``apply_hebbian_learning``.
- Works with models that expose a parameter object with a ``.value`` ndarray of shape
(N, N) (e.g., ``bm.Variable``).
Generic rule
- For patterns ``x`` (shape: (N,)), compute optional mean activity ``rho`` and update
``W <- W + sum_i (x_i - rho)(x_i - rho)^T``.
- Options allow zeroing the diagonal and normalizing by number of patterns.
Key options
- ``weight_attr``: Name of the weight attribute on the model (default: "W").
- ``subtract_mean``: Whether to center patterns by mean activity ``rho``.
- ``zero_diagonal``: Whether to set diagonal of ``W`` to zero after update.
- ``normalize_by_patterns``: Divide accumulated outer-products by number of patterns.
- ``prefer_generic``: Prefer the generic Hebbian rule over model-specific method.
- ``state_attr``: Name of the state vector attribute for prediction (default: ``s``; or
model-provided ``predict_state_attr``).
- ``prefer_generic_predict``: Prefer the trainer's generic predict loop over the
model's ``predict`` implementation (falls back automatically when unsupported).
"""
def __init__(
self,
model: BrainInspiredModel,
show_iteration_progress: bool = False, # Default to False for cleaner display
compiled_prediction: bool = True,
*,
# Generic Hebbian options
weight_attr: str | None = "W",
subtract_mean: bool = True,
zero_diagonal: bool = True,
normalize_by_patterns: bool = True,
prefer_generic: bool = True,
# Generic predict options
state_attr: str | None = None,
prefer_generic_predict: bool = True,
preserve_on_resize: bool = True,
):
"""
Initialize Hebbian trainer.
Args:
model: The model to train
show_iteration_progress: Whether to show progress for individual pattern convergence
compiled_prediction: Whether to use compiled prediction by default (faster but no iteration progress)
weight_attr: Name of model attribute holding the connection weights (default: "W").
subtract_mean: Subtract dataset mean activity (rho) before outer-product.
zero_diagonal: Force zero self-connections after update.
normalize_by_patterns: Divide accumulated outer-products by number of patterns.
prefer_generic: If True, use trainer's generic Hebbian rule when possible; otherwise
call the model's own implementation if available.
"""
super().__init__(
model=model,
show_iteration_progress=show_iteration_progress,
compiled_prediction=compiled_prediction,
)
# Generic Hebbian config
[docs]
self.weight_attr = weight_attr
[docs]
self.subtract_mean = subtract_mean
[docs]
self.zero_diagonal = zero_diagonal
[docs]
self.normalize_by_patterns = normalize_by_patterns
[docs]
self.prefer_generic = prefer_generic
# Generic predict config
[docs]
self.state_attr = state_attr
[docs]
self.prefer_generic_predict = prefer_generic_predict
[docs]
self.preserve_on_resize = preserve_on_resize
[docs]
def train(self, train_data: Iterable):
"""
Train the model using Hebbian learning.
Behavior
- Preferred path: apply a generic Hebbian update directly to ``model.<weight_attr>``.
- Fallback path: call ``model.apply_hebbian_learning(train_data)`` if generic path
is unavailable.
Requirements for generic path
- Model must expose ``model.<weight_attr>`` with a ``.value`` array of shape (N, N).
- Optionally, models can declare ``weight_attr`` property to specify the
attribute name, allowing ``HebbianTrainer(..., weight_attr=None)``.
"""
used_generic = False
# Materialize training data (avoid consuming generators twice)
patterns = [jnp.asarray(p) for p in train_data]
if len(patterns) == 0:
return
# Determine the weight attribute to use (allow model override via `weight_attr`)
weight_attr = self.weight_attr
if weight_attr is None and hasattr(self.model, "weight_attr"):
try:
weight_attr = self.model.weight_attr # could be property/str
if callable(weight_attr):
weight_attr = weight_attr()
except Exception:
weight_attr = None
# Ensure model dimensionality matches training patterns (use first pattern)
n = int(jnp.asarray(patterns[0]).shape[0])
self._ensure_model_dim(n, weight_attr)
# Try generic path if preferred
if self.prefer_generic and weight_attr is not None:
param = getattr(self.model, weight_attr, None)
if param is not None and hasattr(param, "value"):
W = param.value
if (
W is not None
and hasattr(W, "shape")
and len(W.shape) == 2
and W.shape[0] == W.shape[1]
):
self._apply_generic_hebbian(patterns, param)
used_generic = True
# Fallback to model-specific implementation if generic path wasn't used
if not used_generic:
if hasattr(self.model, "apply_hebbian_learning"):
self.model.apply_hebbian_learning(patterns)
else:
raise AttributeError(
"Model does not expose a suitable weight attribute for generic Hebbian "
"learning and has no `apply_hebbian_learning` method."
)
def _compute_weight_update(self, patterns: list[jnp.ndarray], sign: float = 1.0) -> jnp.ndarray:
"""
Compute weight update from patterns using vectorized JAX operations.
This method extracts the common logic for computing Hebbian-style updates,
used by both HebbianTrainer (+) and AntiHebbianTrainer (-).
Args:
patterns: List of 1D JAX arrays of shape (N,)
sign: Multiplicative factor (+1.0 for Hebbian, -1.0 for Anti-Hebbian)
Returns:
Weight update matrix of shape (N, N)
"""
# Stack patterns into array (P, N) where P is number of patterns
patterns_array = jnp.stack(patterns, axis=0)
num_patterns, n = patterns_array.shape
# Compute mean activity across all patterns if requested
if self.subtract_mean:
rho = jnp.mean(patterns_array)
patterns_array = patterns_array - rho
# Vectorized outer product computation using vmap
# Maps outer product over each pattern: (P, N) -> (P, N, N)
outer_products = jax.vmap(lambda p: jnp.outer(p, p))(patterns_array)
# Sum across patterns: (P, N, N) -> (N, N)
W_accum = jnp.sum(outer_products, axis=0)
# Normalize by number of patterns if requested
if self.normalize_by_patterns:
W_accum = W_accum / num_patterns
# Apply sign for Hebbian (+) or Anti-Hebbian (-)
return sign * W_accum
def _apply_generic_hebbian(self, train_data: Iterable, weight_param) -> None:
"""
Apply generic Hebbian learning.
Rule
- ``W <- W + Σ (t t^T)`` where ``t = x - rho`` if centering enabled, otherwise ``t = x``.
- If ``normalize_by_patterns`` is True, divide by number of patterns.
- If ``zero_diagonal`` is True, set diagonal to zero after update.
Args
- train_data: Iterable of 1D patterns (numpy/jax arrays) of shape (N,).
- weight_param: Parameter object with ``.value`` as ndarray (N, N).
"""
# Gather patterns as jax arrays
patterns = [jnp.asarray(p, dtype=jnp.float32) for p in train_data]
if not patterns:
return
# Validate pattern dimensions
n = patterns[0].shape[0]
for p in patterns:
if p.ndim != 1 or p.shape[0] != n:
raise ValueError("All patterns must be 1D with consistent length.")
# Compute weight update using vectorized operations
W_accum = self._compute_weight_update(patterns, sign=1.0)
# Update with existing weights (Hebbian: addition)
W_new = jnp.asarray(weight_param.value, dtype=jnp.float32) + W_accum
# Force zero diagonal if required
if self.zero_diagonal:
W_new = W_new.at[jnp.diag_indices(W_new.shape[0])].set(0.0)
weight_param.value = W_new
[docs]
def predict(
self,
pattern,
num_iter: int = 20,
compiled: bool | None = None,
show_progress: bool | None = None,
convergence_threshold: float = 1e-10,
progress_callback: Callable[[int, float, bool, float], None] | None = None,
):
"""
Predict a single pattern.
Args:
pattern: Input pattern to predict
num_iter: Maximum number of iterations
compiled: Override default compiled setting
show_progress: Override default progress setting
convergence_threshold: Energy change threshold for convergence
Returns:
Predicted pattern
"""
# Always use compiled path; ignore `compiled` and iteration progress flags.
# Keep parameters for backward compatibility.
compiled = True
if show_progress is None:
show_progress = False
# Create progress bar callback if needed
bar_callback = None
pbar = None
if show_progress and not compiled:
pbar = tqdm(total=num_iter, desc="Converging", ncols=80, leave=False)
def bar_callback(iteration, energy, converged, energy_change):
# Update with simpler format to avoid clutter
status_icon = "✓" if converged else "→"
energy_str = f"{energy:.0f}" if abs(energy) > 1000 else f"{energy:.3f}"
pbar.set_postfix(E=energy_str, st=status_icon)
pbar.update(1)
if converged:
# Fill remaining iterations to show completion
remaining = num_iter - iteration
if remaining > 0:
pbar.update(remaining)
# Always use generic predict (no backward-compat call to model.predict)
# Check capability: model must have update, energy, and a vector state attr
state_attr = self._resolve_state_attr()
state_param = getattr(self.model, state_attr, None)
if not (
hasattr(self.model, "update")
and hasattr(self.model, "energy")
and state_param is not None
and hasattr(state_param, "value")
):
raise AttributeError(
"Generic prediction requires model.update, model.energy, and a state vector "
f"attribute '{state_attr}' with a '.value' array."
)
# Initialize state
# Ensure dimensionality matches pattern
n = int(jnp.asarray(pattern).shape[0])
self._ensure_model_dim(
n,
self.weight_attr
if self.weight_attr is not None
else getattr(self.model, "weight_attr", None),
)
# Refresh state_param after potential resize
state_attr = self._resolve_state_attr()
state_param = getattr(self.model, state_attr, None)
self._set_state_vector(pattern, state_param)
# Prepare combined callback
def combined_callback(iteration, energy, converged, energy_change):
if progress_callback is not None:
try:
progress_callback(iteration, energy, converged, energy_change)
except Exception:
pass
if bar_callback is not None:
bar_callback(iteration, energy, converged, energy_change)
# Run (always compiled path)
try:
result = self._predict_generic_compiled(num_iter, state_param)
finally:
if pbar is not None:
pbar.close()
return result
def _ensure_model_dim(self, n: int, weight_attr: str | None):
"""Ensure model parameter/state dimensionality equals n.
Tries model.resize if available; otherwise, adjusts weight/state arrays directly.
"""
# Prefer model-provided resize
if hasattr(self.model, "resize"):
try:
self.model.resize(n, preserve_submatrix=self.preserve_on_resize)
return
except Exception:
pass
# Fallback: attempt to resize arrays
if weight_attr is not None:
param = getattr(self.model, weight_attr, None)
if param is not None and hasattr(param, "value"):
W = jnp.asarray(param.value)
if W.ndim != 2 or W.shape[0] != n or W.shape[1] != n:
new_W = jnp.zeros((n, n), dtype=jnp.float32)
if self.preserve_on_resize and W.ndim == 2:
m0 = min(W.shape[0], n)
m1 = min(W.shape[1], n)
new_W = new_W.at[:m0, :m1].set(W[:m0, :m1])
param.value = new_W - jnp.diag(jnp.diag(new_W))
# State vector
state_attr = self._resolve_state_attr()
state_param = getattr(self.model, state_attr, None)
if state_param is not None and hasattr(state_param, "value"):
s = jnp.asarray(state_param.value)
if s.ndim != 1 or s.shape[0] != n:
state_param.value = jnp.ones((n,), dtype=jnp.float32)
def _resolve_state_attr(self) -> str:
"""
Resolve the name of the state attribute to use for predictions.
Checks in order:
1. Explicit state_attr parameter from constructor
2. Model's predict_state_attr hint (method or property)
3. Default "s"
Returns:
str: Name of the state attribute (e.g., "s", "u", "r")
"""
# Explicit override takes precedence
if self.state_attr is not None:
return self.state_attr
# Model-provided hint
if hasattr(self.model, "predict_state_attr"):
try:
attr = self.model.predict_state_attr
if callable(attr):
attr = attr()
if isinstance(attr, str):
return attr
except Exception:
pass
# Default
return "s"
def _set_state_vector(self, pattern, state_param) -> None:
"""
Set model state vector from a pattern array.
Args:
pattern: Input pattern to set as state
state_param: State parameter object with .value attribute
"""
vec = jnp.asarray(pattern, dtype=jnp.float32)
state_param.value = vec
def _get_state_vector(self, state_param):
"""
Get current model state as a JAX array.
Args:
state_param: State parameter object with .value attribute
Returns:
jnp.ndarray: Current state vector
"""
return jnp.asarray(state_param.value, dtype=jnp.float32)
def _predict_generic_compiled(self, num_iter: int, state_param):
"""
Run prediction with JAX-compiled for loop for maximum performance.
Uses bm.for_loop for efficient execution on GPU/TPU.
No early stopping or progress tracking for compilation compatibility.
Args:
num_iter: Fixed number of iterations to run
state_param: State parameter to update
Returns:
Final state vector after num_iter iterations
"""
# Initial energy
initial_energy = jnp.float32(self.model.energy)
def step_fn(prev_energy, i):
# Single update step with previous energy
self.model.update(prev_energy)
# Return new energy for next iteration (carry) and None (output)
new_energy = jnp.float32(self.model.energy)
return new_energy, None
# Run scan (modifies model state in-place, carries energy forward)
final_energy, _ = bm.scan(step_fn, initial_energy, bm.arange(num_iter))
# Return final state
return self._get_state_vector(state_param)
def _predict_generic_uncompiled(
self,
num_iter: int,
progress_callback,
convergence_threshold: float,
state_param,
):
"""
Run prediction with Python loop allowing early stopping and progress tracking.
Uses standard Python for loop enabling convergence checks and callbacks.
Less efficient than compiled version but provides more control and feedback.
Args:
num_iter: Maximum number of iterations
progress_callback: Optional callback(iter, energy, converged, delta)
convergence_threshold: Energy change threshold for early stopping
state_param: State parameter to update
Returns:
Final state vector (may stop early if converged)
"""
prev_energy = float(self.model.energy)
for iteration in range(num_iter):
self.model.update(prev_energy)
current_energy = float(self.model.energy)
energy_change = abs(current_energy - prev_energy)
converged = energy_change < convergence_threshold
if progress_callback is not None:
progress_callback(iteration + 1, current_energy, converged, energy_change)
if converged:
break
prev_energy = current_energy
return self._get_state_vector(state_param)
[docs]
def predict_batch(
self,
patterns: list,
num_iter: int = 20,
compiled: bool | None = None,
show_sample_progress: bool = True,
show_iteration_progress: bool | None = None,
convergence_threshold: float = 1e-10,
) -> list:
"""
Predict multiple patterns with progress reporting.
Args:
patterns: List of input patterns to predict
num_iter: Maximum number of iterations per pattern
compiled: Override default compiled setting
show_sample_progress: Whether to show progress across samples
show_iteration_progress: Override default iteration progress setting
convergence_threshold: Energy change threshold for convergence
Returns:
List of predicted patterns
"""
# Always use compiled path (ignore flags)
compiled = True
show_iteration_progress = False
results = []
# Create sample-level progress bar
sample_pbar = None
if show_sample_progress:
sample_pbar = tqdm(total=len(patterns), desc="Processing samples", ncols=80, leave=True)
try:
for i, pattern in enumerate(patterns):
# Predict single pattern
result = self.predict(
pattern,
num_iter=num_iter,
compiled=compiled,
show_progress=show_iteration_progress,
convergence_threshold=convergence_threshold,
)
results.append(result)
# Update sample progress
if sample_pbar is not None:
sample_pbar.set_postfix(sample=f"{i + 1}/{len(patterns)}")
sample_pbar.update(1)
finally:
if sample_pbar is not None:
sample_pbar.close()
return results
[docs]
class AntiHebbianTrainer(HebbianTrainer):
"""
Anti-Hebbian trainer for pattern decorrelation and unlearning.
Overview
- Implements anti-Hebbian learning rule: "Neurons that fire together, wire apart"
- Uses negative weight updates: ``W <- W - Σ (t t^T)`` instead of positive
- Inherits all functionality from HebbianTrainer (predict, predict_batch, etc.)
Applications
- Sparse coding and independent component analysis
- Competitive learning networks
- Decorrelation and whitening of feature representations
- Lateral inhibition modeling
- Selective forgetting / pattern unlearning
Learning Rule
- For patterns ``x``, compute optional mean activity ``rho`` and update:
``W <- W - sum_i (x_i - rho)(x_i - rho)^T`` (note the minus sign)
- If ``subtract_mean=True``, patterns are centered by mean: ``t = x - rho``
- If ``normalize_by_patterns=True``, divide by number of patterns
- All options from HebbianTrainer apply (subtract_mean, zero_diagonal, etc.)
Example
>>> model = AmariHopfieldNetwork(num_neurons=100, activation="tanh")
>>> # Train with Hebbian first
>>> hebb_trainer = HebbianTrainer(model)
>>> hebb_trainer.train(all_patterns)
>>> # Then apply anti-Hebbian to unlearn specific pattern
>>> anti_trainer = AntiHebbianTrainer(model, subtract_mean=False)
>>> anti_trainer.train([pattern_to_forget])
"""
def __init__(self, model: BrainInspiredModel, **kwargs):
"""
Initialize Anti-Hebbian trainer.
Args:
model: The model to train
**kwargs: Additional arguments passed to HebbianTrainer
"""
super().__init__(model, **kwargs)
def _apply_generic_hebbian(self, train_data: Iterable, weight_param) -> None:
"""
Apply anti-Hebbian learning rule.
Rule
- ``W <- W - Σ (t t^T)`` where ``t = x - rho`` if centering enabled, otherwise ``t = x``
- Note the negative sign - this is the key difference from Hebbian learning
- If ``normalize_by_patterns`` is True, divide by number of patterns
- If ``zero_diagonal`` is True, set diagonal to zero after update
Args
- train_data: Iterable of 1D patterns (numpy/jax arrays) of shape (N,)
- weight_param: Parameter object with ``.value`` as ndarray (N, N)
"""
# Gather patterns as jax arrays
patterns = [jnp.asarray(p, dtype=jnp.float32) for p in train_data]
if not patterns:
return
# Validate pattern dimensions
n = patterns[0].shape[0]
for p in patterns:
if p.ndim != 1 or p.shape[0] != n:
raise ValueError("All patterns must be 1D with consistent length.")
# Compute weight update using vectorized operations (negative sign for anti-Hebbian)
W_accum = self._compute_weight_update(patterns, sign=-1.0)
# Update with existing weights (Anti-Hebbian: addition of negative update)
W_new = jnp.asarray(weight_param.value, dtype=jnp.float32) + W_accum
# Force zero diagonal if required
if self.zero_diagonal:
W_new = W_new.at[jnp.diag_indices(W_new.shape[0])].set(0.0)
weight_param.value = W_new