Source code for src.canns.trainer.stdp

"""STDP (Spike-Timing-Dependent Plasticity) trainer."""

from __future__ import annotations

from collections.abc import Iterable

import brainpy.math as bm
import jax.numpy as jnp

from ..models.brain_inspired import BrainInspiredModel
from ._base import Trainer

__all__ = ["STDPTrainer"]


class STDPTrainer(Trainer):
    """
    STDP (Spike-Timing-Dependent Plasticity) trainer.

    STDP is a biologically-inspired learning rule that adjusts synaptic weights
    based on the precise timing of pre- and post-synaptic spikes. Synapses are
    strengthened when pre-synaptic spikes precede post-synaptic spikes (LTP),
    and weakened when the order is reversed (LTD).

    Trace-based Learning Rule:
        ΔW_ij = A_plus * trace_pre[j] * spike_post[i] - A_minus * trace_post[i] * spike_pre[j]

    where:
        - W_ij is the weight from input j to neuron i
        - spike_pre[j] is the presynaptic spike (0 or 1)
        - spike_post[i] is the postsynaptic spike (0 or 1)
        - trace_pre[j] is the exponential trace of presynaptic spikes
        - trace_post[i] is the exponential trace of postsynaptic spikes
        - A_plus controls LTP (long-term potentiation) magnitude
        - A_minus controls LTD (long-term depression) magnitude

    The applied update is ``learning_rate * ΔW``, and the weights are clipped
    to ``[w_min, w_max]`` after every presented pattern.

    The spike traces evolve as:
        trace = decay * trace + spike

    This provides a temporal window for spike-timing correlations.

    References:
        - Gerstner & Kistler (2002): Spiking Neuron Models
        - Morrison et al. (2008): Phenomenological models of synaptic plasticity
        - Bi & Poo (1998): Synaptic modifications in cultured hippocampal neurons
    """

    def __init__(
        self,
        model: BrainInspiredModel,
        learning_rate: float = 0.01,
        A_plus: float = 0.005,
        A_minus: float = 0.00525,
        weight_attr: str = "W",
        w_min: float = 0.0,
        w_max: float = 1.0,
        compiled: bool = True,
        **kwargs,
    ):
        """
        Initialize STDP trainer.

        Args:
            model: The spiking model to train (typically SpikingLayer)
            learning_rate: Global learning rate multiplier (default: 0.01)
            A_plus: LTP magnitude (default: 0.005)
            A_minus: LTD magnitude (default: 0.00525, slightly > A_plus for stability)
            weight_attr: Name of model attribute holding the connection weights
            w_min: Minimum weight value (default: 0.0 for excitatory synapses)
            w_max: Maximum weight value (default: 1.0)
            compiled: Whether to use JIT-compiled training loop (default: True)
            **kwargs: Additional arguments passed to parent Trainer
        """
        super().__init__(model=model, **kwargs)
        self.learning_rate = learning_rate
        self.A_plus = A_plus
        self.A_minus = A_minus
        self.weight_attr = weight_attr
        self.w_min = w_min
        self.w_max = w_max
        self.compiled = compiled

    def train(self, train_data: Iterable):
        """
        Train the model using STDP rule.

        Args:
            train_data: Iterable of input spike patterns (each of shape (input_size,))
                Each pattern should contain binary values (0 or 1)

        Raises:
            AttributeError: If the model lacks the weight parameter named by
                ``weight_attr`` (with a ``.value``), or lacks ``trace_pre`` /
                ``trace_post``.
        """
        # Get weight parameter
        weight_param = getattr(self.model, self.weight_attr, None)
        if weight_param is None or not hasattr(weight_param, "value"):
            raise AttributeError(
                f"Model does not have a '{self.weight_attr}' parameter with .value attribute"
            )

        # Check if model has required trace attributes
        if not hasattr(self.model, "trace_pre"):
            raise AttributeError("Model must have 'trace_pre' attribute for STDP learning")
        if not hasattr(self.model, "trace_post"):
            raise AttributeError("Model must have 'trace_post' attribute for STDP learning")

        if self.compiled:
            self._train_compiled(train_data, weight_param)
        else:
            self._train_uncompiled(train_data, weight_param)

    def _train_compiled(self, train_data: Iterable, weight_param):
        """
        JIT-compiled training loop using ``bm.scan``.

        The forward pass is re-implemented inline as leaky integrate-and-fire
        dynamics (rather than calling ``model.forward``), so the whole pattern
        sequence can be scanned. Weights are updated online: each step's forward
        pass uses the weights produced by the previous step.

        Note: LTP/LTD here use traces that already include the current step's
        spikes, so coincident pre/post spikes contribute to both terms. The
        traces, membrane potential, and weights are initialised from zeros /
        ``weight_param`` on every call (model traces are not read back in).

        Args:
            train_data: Iterable of input spike patterns
            weight_param: Weight parameter object
        """
        # Materialise once; an empty iterable would make jnp.stack raise.
        pattern_list = [jnp.asarray(p, dtype=jnp.float32) for p in train_data]
        if not pattern_list:
            return
        patterns = jnp.stack(pattern_list)

        # Model hyper-parameters, with fallbacks when the model does not define them
        trace_decay = getattr(self.model, "trace_decay", 0.95)
        threshold = getattr(self.model, "threshold", 1.0)
        v_reset = getattr(self.model, "v_reset", 0.0)
        leak = getattr(self.model, "leak", 0.9)

        # Initial state
        W_init = jnp.asarray(weight_param.value, dtype=jnp.float32)
        trace_pre_init = jnp.zeros(self.model.input_size, dtype=jnp.float32)
        trace_post_init = jnp.zeros(self.model.output_size, dtype=jnp.float32)
        v_init = jnp.zeros(self.model.output_size, dtype=jnp.float32)

        # Training step for single pattern
        def train_step(carry, x):
            W, trace_pre, trace_post, v = carry

            # Update pre-synaptic trace
            trace_pre = trace_decay * trace_pre + x

            # Forward pass (LIF dynamics)
            input_current = W @ x
            v = leak * v + input_current

            # Generate spikes, then reset neurons that fired
            spike_post = (v >= threshold).astype(jnp.float32)
            v = jnp.where(spike_post > 0, v_reset, v)

            # Update post-synaptic trace
            trace_post = trace_decay * trace_post + spike_post

            # STDP weight update
            # LTP: pre before post (pre trace high, post spike now)
            ltp = self.A_plus * jnp.outer(spike_post, trace_pre)
            # LTD: post before pre (post trace high, pre spike now)
            ltd = self.A_minus * jnp.outer(trace_post, x)

            W = W + self.learning_rate * (ltp - ltd)

            # Clip weights to valid range
            W = jnp.clip(W, self.w_min, self.w_max)

            return (W, trace_pre, trace_post, v), None

        # Run compiled scan
        (W_final, trace_pre_final, trace_post_final, v_final), _ = bm.scan(
            train_step, (W_init, trace_pre_init, trace_post_init, v_init), patterns
        )

        # Update model parameters
        weight_param.value = W_final
        self.model.trace_pre.value = trace_pre_final
        self.model.trace_post.value = trace_post_final
        # train() does not require a membrane-potential attribute, so only write
        # it back when the model has one (avoids failing after weights are set).
        v_state = getattr(self.model, "v", None)
        if v_state is not None and hasattr(v_state, "value"):
            v_state.value = v_final

    def _train_uncompiled(self, train_data: Iterable, weight_param):
        """
        Python loop training (fallback, slower but more flexible).

        Uses ``model.forward`` for the dynamics. Weights are written back to
        ``weight_param`` after every pattern so that the next forward pass sees
        the updated weights (online learning, as in the compiled path).

        Note: unlike the compiled path, LTP/LTD here use the traces captured
        *before* this step's forward pass, and the model's existing trace state
        is carried over rather than reset.

        Args:
            train_data: Iterable of input spike patterns
            weight_param: Weight parameter object
        """
        W = weight_param.value

        # Process each pattern
        for pattern in train_data:
            x = jnp.asarray(pattern, dtype=jnp.float32)

            # Store current traces before forward pass
            trace_pre_before = self.model.trace_pre.value
            trace_post_before = self.model.trace_post.value

            # Forward pass through model (presumably updates traces and generates spikes)
            spike_post = self.model.forward(x)

            # STDP weight update
            # LTP: pre before post (use pre trace from before post spike)
            ltp = self.A_plus * jnp.outer(spike_post, trace_pre_before)
            # LTD: post before pre (use post trace from before pre spike)
            ltd = self.A_minus * jnp.outer(trace_post_before, x)

            W = W + self.learning_rate * (ltp - ltd)

            # Clip weights to valid range
            W = jnp.clip(W, self.w_min, self.w_max)

            # Commit immediately so the next model.forward uses the new weights
            weight_param.value = W

    def predict(self, pattern, *args, **kwargs):
        """
        Predict output spikes for a single input spike pattern.

        Args:
            pattern: Input spike pattern of shape (input_size,)
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Output spike pattern of shape (output_size,). The fallback path
            returns float32 values of 0 or 1; otherwise whatever
            ``model.forward`` returns.
        """
        if hasattr(self.model, "forward"):
            # NOTE(review): model.forward may also advance the model's internal
            # state (traces, membrane potential) — confirm against the model.
            return self.model.forward(pattern)

        # Fallback: single-step threshold on the weighted input (no leak/reset)
        weight_param = getattr(self.model, self.weight_attr)
        x = jnp.asarray(pattern, dtype=jnp.float32)
        v = weight_param.value @ x
        threshold = getattr(self.model, "threshold", 1.0)
        return (v >= threshold).astype(jnp.float32)