src.canns.trainer.utils

Shared utilities for brain-inspired trainers.

Attributes

stdp_kernel_vec

Functions

compute_running_average(current_avg, new_value, tau)

Compute exponential running average for BCM sliding thresholds.

initialize_spike_buffer(num_neurons, buffer_size)

Initialize spike time buffer for STDP learning.

normalize_weight_rows(W)

Normalize each row of weight matrix to unit norm (for Oja's rule).

stdp_kernel(dt[, tau_plus, tau_minus])

Compute STDP timing kernel for weight change.

update_spike_buffer(buffer, neuron_idx, spike_time)

Update spike buffer with new spike time (circular buffer).

Module Contents

src.canns.trainer.utils.compute_running_average(current_avg, new_value, tau)

Compute exponential running average for BCM sliding thresholds.

Parameters:
  • current_avg (jax.numpy.ndarray) – Current average value

  • new_value (jax.numpy.ndarray) – New value to incorporate

  • tau (float) – Time constant (higher = slower adaptation)

Returns:

Updated running average

Return type:

jax.numpy.ndarray
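The docstring does not state the update formula. A common discrete-time form, assumed in this sketch, moves the average a fraction 1/tau of the way toward each new value, avg + (new - avg) / tau, so a larger tau adapts more slowly. In BCM learning the sliding threshold typically tracks a running average of the squared postsynaptic activity, which the usage lines imitate. The name running_average_sketch is hypothetical and is not the library's implementation.

```python
import jax.numpy as jnp


def running_average_sketch(current_avg, new_value, tau):
    """Hypothetical exponential running average (assumed update rule)."""
    # Move a fraction 1/tau of the way toward new_value; larger tau adapts more slowly.
    return current_avg + (new_value - current_avg) / tau


# Usage: a BCM-style threshold tracking the mean of y**2 for three neurons.
theta = jnp.zeros(3)
y = jnp.array([1.0, 2.0, 0.5])
for _ in range(100):
    theta = running_average_sketch(theta, y**2, tau=10.0)
# The gap to y**2 shrinks by a factor (1 - 1/tau) = 0.9 per step,
# so theta is close to [1.0, 4.0, 0.25] after 100 steps.
```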

src.canns.trainer.utils.initialize_spike_buffer(num_neurons, buffer_size)

Initialize spike time buffer for STDP learning.

Parameters:
  • num_neurons (int) – Number of neurons in the network

  • buffer_size (int) – Number of recent spike times to store per neuron

Returns:

Spike buffer of shape (num_neurons, buffer_size) initialized to -inf

Return type:

jax.numpy.ndarray
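A buffer with this shape and fill value can be built with jnp.full. One plausible reason for using -inf (an inference, not stated in the docstring) is that an empty slot then produces an infinite time difference, which an exponential STDP kernel maps to zero, so unused slots contribute no weight change. The helper name below is hypothetical.

```python
import jax.numpy as jnp


def init_spike_buffer_sketch(num_neurons: int, buffer_size: int) -> jnp.ndarray:
    """Hypothetical helper: one row of past spike times per neuron."""
    # -inf marks "no spike recorded yet" in every slot.
    return jnp.full((num_neurons, buffer_size), -jnp.inf)


buf = init_spike_buffer_sketch(num_neurons=4, buffer_size=5)
# A postsynaptic spike at t = 12.0 compared with an empty presynaptic slot:
dt = 12.0 - buf[0, 0]          # 12 - (-inf) = +inf
decay = jnp.exp(-dt / 20.0)    # exp(-inf) = 0.0, so the empty slot has no effect
```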

src.canns.trainer.utils.normalize_weight_rows(W)

Normalize each row of weight matrix to unit norm (for Oja’s rule).

Parameters:

W (jax.numpy.ndarray) – Weight matrix of shape (N, M)

Returns:

Normalized weight matrix with unit-norm rows

Return type:

jax.numpy.ndarray
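Oja's rule drives each weight vector toward unit length on its own; an explicit row normalization enforces that constraint exactly after an update. The sketch below assumes the Euclidean (L2) norm and that each of the N rows holds one output neuron's weights. The eps parameter is a hypothetical addition that keeps an all-zero row from causing a division by zero; the library may handle zero rows differently.

```python
import jax.numpy as jnp


def normalize_rows_sketch(W, eps=1e-12):
    """Hypothetical row normalization: scale each row of W to unit L2 norm."""
    # Per-row norms with shape (N, 1) so they broadcast across the M columns.
    norms = jnp.linalg.norm(W, axis=1, keepdims=True)
    # Clamp tiny norms to eps to avoid dividing by zero on an all-zero row.
    return W / jnp.maximum(norms, eps)


W = jnp.array([[3.0, 4.0],
               [0.0, 2.0],
               [0.0, 0.0]])
W_unit = normalize_rows_sketch(W)
# Rows become [0.6, 0.8] and [0.0, 1.0]; the zero row stays [0.0, 0.0].
```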

src.canns.trainer.utils.stdp_kernel(dt, tau_plus=20.0, tau_minus=20.0)

Compute STDP timing kernel for weight change.

Parameters:
  • dt (float) – Time difference (post_spike_time - pre_spike_time)

  • tau_plus (float) – Time constant for potentiation (dt > 0)

  • tau_minus (float) – Time constant for depression (dt < 0)

Returns:

Weight change magnitude (positive for potentiation, negative for depression)

Return type:

jax.numpy.ndarray
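The docstring fixes only the sign convention. A standard choice, assumed in this sketch, is the exponential STDP window with unit amplitudes: exp(-dt / tau_plus) for dt > 0 and -exp(dt / tau_minus) for dt < 0. The library's actual amplitudes and its value at dt = 0 are not documented; this hypothetical version returns 0 there.

```python
import jax.numpy as jnp


def stdp_kernel_sketch(dt, tau_plus=20.0, tau_minus=20.0):
    """Hypothetical exponential STDP window with unit amplitudes."""
    dt = jnp.asarray(dt)
    ltp = jnp.exp(-dt / tau_plus)    # pre fires before post (dt > 0): potentiation
    ltd = -jnp.exp(dt / tau_minus)   # post fires before pre (dt < 0): depression
    # jnp.where evaluates both branches elementwise and keeps the matching one;
    # dt == 0 (and NaN) falls through to 0.0.
    return jnp.where(dt > 0, ltp, jnp.where(dt < 0, ltd, 0.0))


dts = jnp.array([-20.0, 0.0, 20.0, jnp.inf])
dw = stdp_kernel_sketch(dts)
# dw is approximately [-0.368, 0.0, 0.368, 0.0]
```

Because jnp.where works elementwise, the same function accepts a scalar dt or an array of time differences.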

src.canns.trainer.utils.update_spike_buffer(buffer, neuron_idx, spike_time)

Update spike buffer with new spike time (circular buffer).

Parameters:
  • buffer (jax.numpy.ndarray) – Current spike buffer of shape (num_neurons, buffer_size)

  • neuron_idx (int) – Index of neuron that spiked

  • spike_time (float) – Time of spike

Returns:

Updated spike buffer

Return type:

jax.numpy.ndarray
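The signature carries no write pointer, so the slot to overwrite has to be derived from the buffer itself. This sketch assumes one such scheme: each row is kept ordered from oldest to newest, and a new spike rolls the row left by one slot so the oldest time wraps to the end and is overwritten. A per-neuron write index would be an alternative design. Because JAX arrays are immutable, the update uses .at[...].set(...), which returns a new array instead of modifying buffer in place. The helper name is hypothetical.

```python
import jax.numpy as jnp


def update_spike_buffer_sketch(buffer, neuron_idx, spike_time):
    """Hypothetical fixed-capacity update: each row is ordered oldest -> newest."""
    row = jnp.roll(buffer[neuron_idx], -1)  # oldest entry wraps to the last slot
    row = row.at[-1].set(spike_time)        # ...where the new spike overwrites it
    return buffer.at[neuron_idx].set(row)   # new array; the input is unchanged


buf = jnp.full((2, 3), -jnp.inf)
for t in (1.0, 2.0, 3.0, 4.0):
    buf = update_spike_buffer_sketch(buf, 0, t)
# Row 0 holds the three most recent spikes [2.0, 3.0, 4.0]; row 1 is still all -inf.
```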

src.canns.trainer.utils.stdp_kernel_vec
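No docstring is given for stdp_kernel_vec. Its name suggests a vectorized form of stdp_kernel, for example one produced with jax.vmap; the sketch below assumes that and uses it to evaluate the kernel over every pair of postsynaptic and presynaptic spike times drawn from two spike-buffer rows. The inner kernel repeats an assumed exponential window with unit amplitudes, and all names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def _stdp_scalar(dt, tau_plus=20.0, tau_minus=20.0):
    # Assumed exponential window: positive for dt > 0, negative for dt < 0,
    # 0.0 otherwise (including NaN from (-inf) - (-inf)).
    return jnp.where(dt > 0, jnp.exp(-dt / tau_plus),
                     jnp.where(dt < 0, -jnp.exp(dt / tau_minus), 0.0))


# Assumption: stdp_kernel_vec behaves roughly like jax.vmap over the scalar kernel.
stdp_kernel_vec_sketch = jax.vmap(_stdp_scalar)


def pairwise_dw(post_times, pre_times):
    # dt[i, j] = post_times[i] - pre_times[j]
    dt = post_times[:, None] - pre_times[None, :]
    # Map the 1-D vectorized kernel over rows to cover the 2-D grid.
    return jax.vmap(stdp_kernel_vec_sketch)(dt)


post = jnp.array([10.0])
pre = jnp.array([0.0, 20.0, -jnp.inf])   # -inf marks an empty buffer slot
dw = pairwise_dw(post, pre)
# dt = [10, -10, +inf], so dw is approx. [[0.607, -0.607, 0.0]]
```

Summing dw over the presynaptic axis would give one total weight change per postsynaptic spike, with empty slots contributing nothing.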