src.canns.trainer.utils
Shared utilities for brain-inspired trainers.
Attributes
Functions
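- compute_running_average(current_avg, new_value, tau) – Compute exponential running average for BCM sliding thresholds.
- initialize_spike_buffer(num_neurons, buffer_size) – Initialize spike time buffer for STDP learning.
- normalize_weight_rows(W) – Normalize each row of weight matrix to unit norm (for Oja's rule).
- stdp_kernel(dt, tau_plus=20.0, tau_minus=20.0) – Compute STDP timing kernel for weight change.
- Update spike buffer with new spike time (circular buffer).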
Module Contents
- src.canns.trainer.utils.compute_running_average(current_avg, new_value, tau)
Compute exponential running average for BCM sliding thresholds.
- Parameters:
current_avg (jax.numpy.ndarray) – Current average value
new_value (jax.numpy.ndarray) – New value to incorporate
tau (float) – Time constant (higher = slower adaptation)
- Returns:
Updated running average
- Return type:
jax.numpy.ndarray
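The docstring does not state the exact update rule. A minimal sketch, assuming the common first-order form avg + (new - avg) / tau (so tau = 1 replaces the average outright and a larger tau adapts more slowly, matching the parameter description), could look like this. `running_average_sketch` is a hypothetical name, not the library's function.

```python
import jax.numpy as jnp


def running_average_sketch(current_avg, new_value, tau):
    """Hypothetical stand-in for compute_running_average.

    Assumes the first-order update avg + (new - avg) / tau; the real
    implementation may use a different formula.
    """
    current_avg = jnp.asarray(current_avg)
    new_value = jnp.asarray(new_value)
    # Move the average a fraction 1/tau of the way toward the new value.
    return current_avg + (new_value - current_avg) / tau


# Example use: a BCM-style threshold tracking mean squared activity per neuron.
theta = jnp.zeros(3)
y = jnp.array([1.0, 2.0, 0.0])
for _ in range(5):
    theta = running_average_sketch(theta, y ** 2, tau=10.0)
```

After k steps from zero with a constant input, this rule gives theta = y**2 * (1 - (1 - 1/tau)**k), so the threshold approaches the squared activity geometrically.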
- src.canns.trainer.utils.initialize_spike_buffer(num_neurons, buffer_size)
Initialize spike time buffer for STDP learning.
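Neither the buffer layout nor the circular-buffer update function listed in the summary is documented here. A sketch under stated assumptions: the buffer is a (num_neurons, buffer_size) array filled with -inf to mean "no spike yet", and each neuron keeps its own write pointer that wraps around modulo buffer_size. `init_spike_buffer_sketch` and `update_spike_buffer_sketch` are hypothetical names, and the library's fill value and pointer handling may differ.

```python
import jax.numpy as jnp


def init_spike_buffer_sketch(num_neurons, buffer_size):
    # Assumed layout: one row per neuron, one column per stored spike time.
    # -inf marks empty slots so that exp(-(t - t_spike) / tau) evaluates to 0.
    return jnp.full((num_neurons, buffer_size), -jnp.inf)


def update_spike_buffer_sketch(buffer, write_idx, spiked, t):
    """Hypothetical circular-buffer update.

    buffer:    (N, B) spike-time array
    write_idx: (N,) integer write pointer per neuron (an assumption)
    spiked:    (N,) boolean spike mask
    Returns the new buffer and the advanced pointers.
    """
    n, b = buffer.shape
    rows = jnp.arange(n)
    # Neurons that did not spike keep whatever their current slot holds.
    old = buffer[rows, write_idx]
    buffer = buffer.at[rows, write_idx].set(jnp.where(spiked, t, old))
    # Only spiking neurons advance; the modulo makes the buffer circular.
    write_idx = jnp.where(spiked, (write_idx + 1) % b, write_idx)
    return buffer, write_idx


buf = init_spike_buffer_sketch(num_neurons=2, buffer_size=2)
idx = jnp.zeros(2, dtype=jnp.int32)
for t, spiked in [(1.0, [True, False]), (2.0, [True, True]), (3.0, [True, False])]:
    buf, idx = update_spike_buffer_sketch(buf, idx, jnp.array(spiked), t)
# Row 0 has wrapped around: slot 0 now holds t=3.0 and slot 1 holds t=2.0.
```

A per-neuron pointer keeps the most recent buffer_size spikes of each neuron; a single shared pointer would be simpler but would tie slot reuse to global time steps instead.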
- src.canns.trainer.utils.normalize_weight_rows(W)
Normalize each row of weight matrix to unit norm (for Oja’s rule).
- Parameters:
W (jax.numpy.ndarray) – Weight matrix of shape (N, M)
- Returns:
Normalized weight matrix with unit-norm rows
- Return type:
jax.numpy.ndarray
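Oja's rule drives each weight vector toward unit length; normalizing the rows explicitly enforces that constraint directly. A minimal sketch, dividing each row by its Euclidean norm; the `eps` guard against all-zero rows is an assumption and is not stated in the docstring, and `normalize_rows_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def normalize_rows_sketch(W, eps=1e-12):
    # L2 norm of each row, kept as an (N, 1) column so it broadcasts over (N, M).
    norms = jnp.linalg.norm(W, axis=1, keepdims=True)
    # eps (assumed) avoids division by zero; all-zero rows stay zero.
    return W / jnp.maximum(norms, eps)


W = jnp.array([[3.0, 4.0], [0.0, 2.0]])
W_unit = normalize_rows_sketch(W)
```

Here the first row [3, 4] has norm 5 and becomes [0.6, 0.8], while [0, 2] becomes [0, 1].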
- src.canns.trainer.utils.stdp_kernel(dt, tau_plus=20.0, tau_minus=20.0)
Compute STDP timing kernel for weight change.
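- Parameters:
dt – Time difference between the pre- and postsynaptic spikes
tau_plus (float) – Time constant of the potentiation window
tau_minus (float) – Time constant of the depression window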
- Returns:
Weight change magnitude (positive for potentiation, negative for depression)
- Return type:
jax.numpy.ndarray
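The entry does not document the sign convention for dt. A sketch of the standard exponential STDP window, assuming dt = t_post - t_pre, unit amplitudes, and zero change at dt = 0: pre-before-post (dt > 0) potentiates by exp(-dt / tau_plus), and post-before-pre (dt < 0) depresses by -exp(dt / tau_minus). `stdp_kernel_sketch` is a hypothetical name; the library's amplitudes and convention may differ.

```python
import jax.numpy as jnp


def stdp_kernel_sketch(dt, tau_plus=20.0, tau_minus=20.0):
    dt = jnp.asarray(dt, dtype=jnp.float32)
    # Clamp the exponents so the branch jnp.where discards never overflows.
    ltp = jnp.exp(-jnp.maximum(dt, 0.0) / tau_plus)   # potentiation, used for dt > 0
    ltd = -jnp.exp(jnp.minimum(dt, 0.0) / tau_minus)  # depression, used for dt < 0
    return jnp.where(dt > 0, ltp, jnp.where(dt < 0, ltd, 0.0))


dw = stdp_kernel_sketch(jnp.array([-20.0, 0.0, 20.0]))
```

With the default time constants, a 20 ms lag in either direction gives a magnitude of exp(-1), positive for pre-before-post and negative for post-before-pre. Clamping before exponentiating matters under `jax.grad`, because an overflowing value in the unselected branch of `jnp.where` can still produce NaN gradients.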