Source code for src.canns.trainer.utils
"""Shared utilities for brain-inspired trainers."""
from __future__ import annotations
import jax.numpy as jnp
[docs]
def compute_running_average(
    current_avg: jnp.ndarray, new_value: jnp.ndarray, tau: float
) -> jnp.ndarray:
    """
    Compute exponential running average for BCM sliding thresholds.

    Implements ``avg <- (1 - alpha) * avg + alpha * new_value`` with
    ``alpha = 1 / tau``. The step size is capped at 1, so any ``tau <= 1``
    (including zero or negative values) simply replaces the average with
    ``new_value`` instead of overshooting it.

    Args:
        current_avg: Current average value
        new_value: New value to incorporate
        tau: Time constant (higher = slower adaptation). May be a Python
            float or a JAX scalar; the computation is traceable under
            ``jax.jit``.

    Returns:
        Updated running average
    """
    # Clamp tau to >= 1 so alpha stays in (0, 1]. Without the clamp,
    # 0 < tau < 1 yields alpha > 1 and a negative weight on current_avg,
    # i.e. the result extrapolates past new_value instead of moving toward it.
    # jnp.maximum (rather than a Python `if`) also avoids converting a traced
    # tau to a bool, which would fail under jit.
    alpha = 1.0 / jnp.maximum(tau, 1.0)
    return (1.0 - alpha) * current_avg + alpha * new_value
[docs]
def normalize_weight_rows(W: jnp.ndarray) -> jnp.ndarray:
    """
    Rescale each row of a weight matrix to unit L2 norm (used by Oja's rule).

    Rows whose norm is below ``1e-10`` (e.g. all-zero rows) are divided by
    1.0, i.e. returned unscaled, to avoid dividing by a near-zero value.

    Args:
        W: Weight matrix of shape (N, M)

    Returns:
        Array shaped like ``W`` whose rows have unit norm, apart from the
        near-zero rows described above
    """
    tiny = 1e-10
    lengths = jnp.linalg.norm(W, ord=None, axis=1, keepdims=True)
    # Substitute 1.0 for (near-)zero norms so the division stays finite.
    safe_lengths = jnp.where(lengths < tiny, jnp.ones_like(lengths), lengths)
    return jnp.divide(W, safe_lengths)
[docs]
def initialize_spike_buffer(num_neurons: int, buffer_size: int) -> jnp.ndarray:
    """
    Allocate the spike-time buffer used for STDP learning.

    Every slot starts out holding ``-inf``.

    Args:
        num_neurons: Number of rows (one per neuron)
        buffer_size: Number of spike-time slots kept per neuron

    Returns:
        float32 array of shape (num_neurons, buffer_size) filled with -inf
    """
    shape = (num_neurons, buffer_size)
    empty_marker = -jnp.inf
    return jnp.full(shape, fill_value=empty_marker, dtype=jnp.float32)
[docs]
def update_spike_buffer(buffer: jnp.ndarray, neuron_idx: int, spike_time: float) -> jnp.ndarray:
    """
    Record a new spike time for one neuron, discarding that neuron's oldest entry.

    The selected row is shifted one slot toward index 0 (the value in slot 0
    is dropped) and ``spike_time`` is written into the last slot. Each row
    therefore holds its recorded times in insertion order, newest last.
    Other rows are left as they were; a new array is returned.

    Args:
        buffer: Spike buffer of shape (num_neurons, buffer_size)
        neuron_idx: Row index of the neuron that spiked
        spike_time: Time of the spike to record

    Returns:
        Updated copy of the spike buffer
    """
    row = buffer[neuron_idx]
    # Shift left by one slot, then place the newest spike in the final slot.
    shifted = row.at[:-1].set(row[1:]).at[-1].set(spike_time)
    return buffer.at[neuron_idx].set(shifted)
[docs]
def stdp_kernel(dt: float, tau_plus: float = 20.0, tau_minus: float = 20.0) -> jnp.ndarray:
    """
    Compute STDP timing kernel for weight change.

    Returns ``exp(-dt / tau_plus)`` where ``dt > 0`` (potentiation: pre before
    post) and ``-exp(dt / tau_minus)`` otherwise (depression: post before
    pre). ``dt == 0`` falls on the depression side and yields ``-1``.

    The computation is element-wise. ``dt`` may be a Python float, a JAX
    scalar, or an array of any shape, and the function can be traced by
    ``jax.jit`` / ``jax.vmap`` (and hence wrapped by ``jnp.vectorize``).

    Args:
        dt: Time difference (post_spike_time - pre_spike_time); scalar or array
        tau_plus: Time constant for potentiation (dt > 0)
        tau_minus: Time constant for depression (dt <= 0)

    Returns:
        Weight change magnitude, broadcast to the shape of ``dt`` (positive
        for potentiation, negative for depression)
    """
    dt = jnp.asarray(dt)
    # Select the branch with jnp.where rather than a Python `if`. A Python
    # branch on `dt` raises for arrays with more than one element and for
    # traced values (jit / vmap / jnp.vectorize).
    # Both branches use -|dt|, which is algebraically the same as the
    # original formulas on the side each branch is selected. For positive
    # time constants every exponent is then <= 0, so the unselected branch
    # never overflows to inf (inf there would turn gradients through
    # jnp.where into NaN).
    decay = jnp.abs(dt)
    potentiation = jnp.exp(-decay / tau_plus)
    depression = -jnp.exp(-decay / tau_minus)
    return jnp.where(dt > 0, potentiation, depression)
# Vectorized version for batch processing
[docs]
# Element-wise wrapper around ``stdp_kernel``: the first positional argument
# (``dt``) is vectorized, while positional arguments 1 and 2 (``tau_plus``,
# ``tau_minus``) are excluded from vectorization and passed through unchanged.
# NOTE(review): jnp.vectorize maps via jax.vmap, so stdp_kernel must not use
# Python control flow on ``dt`` for this wrapper to accept array inputs.
stdp_kernel_vec = jnp.vectorize(stdp_kernel, excluded=frozenset({1, 2}))