src.canns.models.brain_inspired

Brain-inspired neural network models.

This module contains biologically plausible neural network models that incorporate principles from neuroscience and cognitive science, including associative memory, Hebbian learning, and other brain-inspired mechanisms.

Classes

AmariHopfieldNetwork

Amari-Hopfield Network implementation supporting both discrete and continuous dynamics.

BrainInspiredModel

Base class for brain-inspired models.

BrainInspiredModelGroup

Base class for groups of brain-inspired models.

LinearLayer

Generic linear feedforward layer supporting multiple brain-inspired learning rules.

SpikingLayer

Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.

Package Contents

class src.canns.models.brain_inspired.AmariHopfieldNetwork(num_neurons, asyn=False, threshold=0.0, activation='sign', temperature=1.0, **kwargs)

Bases: src.canns.models.brain_inspired._base.BrainInspiredModel

Amari-Hopfield Network implementation supporting both discrete and continuous dynamics.

This class implements Hopfield networks with flexible activation functions, supporting both discrete binary states and continuous dynamics. The network performs pattern completion through energy minimization using asynchronous or synchronous updates.
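The two update schemes can be sketched in plain NumPy (the library itself works with JAX arrays). This is a minimal illustration under stated assumptions, not the class's implementation. It assumes the "sign" activation, with ties at the threshold mapped to +1. It stores a single pattern with the standard Hebbian outer-product rule, scaled by 1/N and with a zero diagonal. The helper names sign_activation, sync_step and async_step are hypothetical.

```python
import numpy as np

def sign_activation(h, threshold=0.0):
    # Map local fields to {-1, +1}; ties at the threshold go to +1 (assumption).
    return np.where(h >= threshold, 1.0, -1.0)

def sync_step(W, s, threshold=0.0):
    # Synchronous update: every neuron reads the same, old state vector.
    return sign_activation(W @ s, threshold)

def async_step(W, s, rng, threshold=0.0):
    # Asynchronous update: neurons are visited one at a time in random order,
    # and each one sees the states already updated earlier in the sweep.
    s = s.copy()
    for i in rng.permutation(len(s)):
        s[i] = 1.0 if W[i] @ s >= threshold else -1.0
    return s

# Assumed storage rule: Hebbian outer product scaled by 1/N, no self-connections.
rng = np.random.default_rng(0)
pattern = rng.choice([-1.0, 1.0], size=32)
W = np.outer(pattern, pattern) / pattern.size
np.fill_diagonal(W, 0.0)

# Corrupt 4 of the 32 bits and let each scheme complete the pattern.
noisy = pattern.copy()
noisy[:4] *= -1
recalled_sync = sync_step(W, noisy)
recalled_async = async_step(W, noisy, rng)
```

With only one stored pattern and a few flipped bits, both schemes recover the pattern in a single sweep. The schemes differ when many patterns interfere: synchronous updates can fall into two-state oscillations, whereas asynchronous updates never increase the energy.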

The network energy function is E = -0.5 * Σ_ij W_ij * s_i * s_j,

where s_i is either discrete ({-1, +1}) or continuous, depending on the activation function.
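As a worked check of the energy formula, the following NumPy sketch evaluates E as a quadratic form. The helper hopfield_energy is hypothetical, and it assumes the same Hebbian outer-product storage as a textbook Hopfield network. For a single stored 4-neuron pattern, the stored state has energy -1.5, while a copy with one flipped bit has energy 0.

```python
import numpy as np

def hopfield_energy(W, s):
    # E = -0.5 * sum_ij W_ij * s_i * s_j, computed as the quadratic form s^T W s.
    return -0.5 * s @ W @ s

# Assumed storage rule: Hebbian outer product scaled by 1/N, no self-connections.
pattern = np.array([1.0, -1.0, 1.0, -1.0])
W = np.outer(pattern, pattern) / pattern.size
np.fill_diagonal(W, 0.0)

noisy = pattern.copy()
noisy[0] = -noisy[0]  # flip one bit

e_stored = hopfield_energy(W, pattern)  # -1.5
e_noisy = hopfield_energy(W, noisy)     # 0.0
```

The stored pattern sits at a lower energy than its corrupted copy. This is why energy minimization performs pattern completion.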

References:

Amari, S. (1977). Neural theory of association and concept-formation. Biological Cybernetics, 26(3), 175-185.

Hopfield, J. J. (1982). Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences of the USA, 79(8), 2554-2558.

Initialize the Amari-Hopfield Network.

Parameters:
  • num_neurons (int) – Number of neurons in the network

  • asyn (bool) – If True, update neurons asynchronously (one at a time); if False (default), update all neurons synchronously at each step

  • threshold (float) – Threshold for activation function

  • activation (str) – Activation function type (“sign”, “tanh”, “sigmoid”)

  • temperature (float) – Temperature parameter for continuous activations

  • **kwargs – Additional arguments passed to parent class

compute_overlap(pattern1, pattern2)

Compute overlap between two binary patterns.

Parameters:
  • pattern1 – First binary pattern to compare

  • pattern2 – Second binary pattern to compare

Returns:

Overlap value (1 for identical, 0 for orthogonal, -1 for opposite)
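A common definition that produces exactly this range is the normalized dot product m = (1/N) Σ_i p1_i * p2_i of two {-1, +1} patterns. The sketch below assumes that definition; the function name overlap is hypothetical, and the library may normalize differently.

```python
import numpy as np

def overlap(p1, p2):
    # Normalized overlap m = (1/N) * sum_i p1_i * p2_i for {-1, +1} patterns.
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    return float(p1 @ p2) / p1.size

a = [1, -1, 1, -1]
m_same = overlap(a, a)                  # 1.0  (identical)
m_orth = overlap(a, [1, 1, -1, -1])     # 0.0  (orthogonal)
m_opp = overlap(a, [-1, 1, -1, 1])      # -1.0 (opposite)
```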

resize(num_neurons, preserve_submatrix=True)

Resize the network dimension and state/weights.

Parameters:
  • num_neurons (int) – New neuron count (N)

  • preserve_submatrix (bool) – If True, copy the top-left min(old, N) block of W into the new matrix; otherwise reinitialize W with zeros.
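The weight-matrix part of these resize semantics can be sketched in NumPy. The helper resize_weights is hypothetical and covers only W; how the class reinitializes the state vector s after resizing is not stated here.

```python
import numpy as np

def resize_weights(W, n_new, preserve_submatrix=True):
    # Start from an all-zero (n_new, n_new) matrix.
    W_new = np.zeros((n_new, n_new), dtype=W.dtype)
    if preserve_submatrix:
        # Copy the top-left min(old, new) block; the rest stays zero.
        k = min(W.shape[0], n_new)
        W_new[:k, :k] = W[:k, :k]
    return W_new

W = np.arange(9.0).reshape(3, 3)
grown = resize_weights(W, 5)                            # old 3x3 block, zero padding
shrunk = resize_weights(W, 2)                           # top-left 2x2 block
cleared = resize_weights(W, 4, preserve_submatrix=False)
```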

update(e_old)

Update network state for one time step.

W
activation
asyn = False
property energy

Compute the energy of the network state.

num_neurons
s
property storage_capacity

Get theoretical storage capacity.

Returns:

Theoretical storage capacity (approximately N/(4*ln(N)))
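A short worked example of the formula N/(4*ln(N)) follows. The helper name storage_capacity mirrors the property but is a standalone sketch. It returns the unrounded value; whether the property rounds is not stated.

```python
import math

def storage_capacity(n):
    # Approximate number of patterns a network of n neurons can store: n / (4 ln n).
    return n / (4 * math.log(n))

cap_100 = storage_capacity(100)    # about 5.43 patterns
cap_1000 = storage_capacity(1000)  # about 36.19 patterns
```

Capacity grows sublinearly: ten times more neurons store roughly 6.7 times more patterns.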

temperature = 1.0
threshold = 0.0
class src.canns.models.brain_inspired.BrainInspiredModel

Bases: src.canns.models.basic._base.BasicModel

Base class for brain-inspired models.

Trainer compatibility notes:

  • If a model wants to support generic Hebbian training, it should expose a weight parameter attribute with a .value array of shape (N, N) (commonly a bm.Variable). The recommended attribute name is W.

  • Override weight_attr to declare a different attribute name if needed. Models that use standard backprop may omit this entirely.

  • Implementing apply_hebbian_learning is optional; prefer letting the trainer handle the generic rule when applicable. Implement this only when you need model-specific behavior.
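A minimal sketch shows how a generic trainer can find the weights through weight_attr without model-specific code. Everything here is hypothetical: Param stands in for bm.Variable, TinyModel for a real model, and hebbian_fit for the library's HebbianTrainer. The learning rule is the textbook outer-product rule scaled by 1/N with a zero diagonal.

```python
import numpy as np

class Param:
    """Hypothetical stand-in for bm.Variable: just holds a .value array."""
    def __init__(self, value):
        self.value = value

class TinyModel:
    """Hypothetical model that exposes W with .value of shape (N, N)."""
    def __init__(self, n):
        self.W = Param(np.zeros((n, n)))

    @property
    def weight_attr(self):
        return "W"

def hebbian_fit(model, patterns):
    # Look the weight parameter up by name, as a generic trainer would.
    param = getattr(model, model.weight_attr)
    P = np.asarray(patterns, dtype=float)  # shape (num_patterns, N)
    W = P.T @ P / P.shape[1]               # sum of outer products, scaled by 1/N
    np.fill_diagonal(W, 0.0)               # no self-connections
    param.value = W

model = TinyModel(4)
hebbian_fit(model, [[1, -1, 1, -1]])
```

Because the trainer only relies on getattr(model, model.weight_attr).value, a model whose weights live under another name only needs to override weight_attr.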

Notes on prediction compatibility: for the trainer’s generic prediction path, models typically expose:

  1. an update(prev_energy) method to advance one step (optional; not all models require energy-driven updates),

  2. an energy property to compute current energy (scalar-like),

  3. a state vector attribute (default s) with .value as 1D array used as the prediction state; override predict_state_attr to change the name.
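The prediction path can be sketched as a loop that writes the input into the state attribute and calls update until the energy stops changing. The names TinyHopfield, Param and generic_predict are hypothetical, and the stopping rule (np.isclose on successive energies, capped at max_steps) is an assumption, not the trainer's actual criterion.

```python
import numpy as np

class Param:
    """Hypothetical stand-in for bm.Variable."""
    def __init__(self, value):
        self.value = value

class TinyHopfield:
    """Hypothetical model exposing update(prev_energy), energy and s (.value)."""
    def __init__(self, W):
        self.W = Param(W)
        self.s = Param(np.ones(W.shape[0]))

    @property
    def predict_state_attr(self):
        return "s"

    @property
    def energy(self):
        s = self.s.value
        return -0.5 * s @ self.W.value @ s

    def update(self, prev_energy):
        # prev_energy is accepted for interface compatibility; unused here.
        h = self.W.value @ self.s.value
        self.s.value = np.where(h >= 0.0, 1.0, -1.0)

def generic_predict(model, pattern, max_steps=20):
    state = getattr(model, model.predict_state_attr)
    state.value = np.asarray(pattern, dtype=float)
    energy = model.energy
    for _ in range(max_steps):
        model.update(energy)
        new_energy = model.energy
        if np.isclose(new_energy, energy):
            break  # energy stopped changing: treat as converged
        energy = new_energy
    return state.value

stored = np.array([1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
W = np.outer(stored, stored) / stored.size
np.fill_diagonal(W, 0.0)
probe = stored.copy()
probe[0] = -probe[0]  # one corrupted bit
recalled = generic_predict(TinyHopfield(W), probe)
```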

Optional resizing: models may implement resize(num_neurons: int, preserve_submatrix: bool = True) to allow trainers to change neuron dimensionality on the fly (e.g., when training with patterns of a different length). When implemented, the trainer calls it to align dimensions before training or prediction.

abstractmethod apply_hebbian_learning(train_data)

Optional model-specific Hebbian learning implementation.

The generic HebbianTrainer can update W directly without requiring this method. Only implement when custom behavior deviates from the generic rule.

abstractmethod predict(pattern)
abstractmethod resize(num_neurons, preserve_submatrix=True)

Optional method to resize model state/parameters to num_neurons.

Default implementation is a stub. Subclasses may override to support dynamic dimensionality changes.

abstractmethod property energy: float

Current energy of the model state (used for convergence checks in prediction).

Implementations may return a float or a 0-dim array; the trainer treats it as a scalar.

property predict_state_attr: str

Name of the state vector attribute used by generic prediction.

Override in subclasses if the prediction state is not stored in s.

property weight_attr: str

Name of the connection weight attribute used by generic training.

Override in subclasses if the weight parameter is not named W.

class src.canns.models.brain_inspired.BrainInspiredModelGroup

Bases: src.canns.models.basic._base.BasicModelGroup

Base class for groups of brain-inspired models.

This class manages collections of brain-inspired models and provides coordinated learning and dynamics across multiple model instances.

class src.canns.models.brain_inspired.LinearLayer(input_size, output_size, use_bcm_threshold=False, threshold_tau=100.0, **kwargs)

Bases: src.canns.models.brain_inspired._base.BrainInspiredModel

Generic linear feedforward layer supporting multiple brain-inspired learning rules.

This model provides a simple linear transformation with an optional sliding threshold for BCM-style plasticity. It can be used with various trainers:

  • OjaTrainer: Normalized Hebbian learning for PCA

  • BCMTrainer: Sliding threshold plasticity (requires use_bcm_threshold=True)

  • HebbianTrainer: Standard Hebbian learning

Computation:

y = W @ x

where W is the weight matrix, x is the input, and y is the output.
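The forward computation and one of the listed rules can be illustrated together in NumPy. W has shape (output_size, input_size), so y = W @ x has shape (output_size,). The learning step is the textbook single-output Oja rule Δw = η * y * (x - y * w), not OjaTrainer's code. The toy data, learning rate and step count are assumptions. With the toy data's largest variance along the first input axis, w should align with that axis and reach unit norm.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, output_size = 2, 1

# Weight matrix of shape (output_size, input_size), so y = W @ x has shape (output_size,).
W = rng.normal(scale=0.1, size=(output_size, input_size))

def forward(W, x):
    return W @ x

eta = 0.005  # learning rate (assumed)
for _ in range(4000):
    # Zero-mean toy inputs with standard deviation 2.0 along axis 0 and 0.5 along axis 1.
    x = rng.normal(size=input_size) * np.array([2.0, 0.5])
    y = forward(W, x)
    # Textbook Oja rule for a single output neuron.
    W[0] += eta * y[0] * (x - y[0] * W[0])

w = W[0]
```

The -y² w term is what keeps the weight norm bounded. Plain Hebbian growth Δw = η y x has no such term and diverges.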

For BCM learning, an optional sliding threshold θ tracks output activity:

θ ← θ + (1/τ) * (y² - θ)
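The threshold rule is an exponential moving average of y², applied per output neuron. The sketch below is a hypothetical standalone helper (update_threshold_sketch) with τ set to the documented default threshold_tau=100.0. It shows that θ settles at y² when the output is held constant.

```python
import numpy as np

def update_threshold_sketch(theta, y, tau=100.0):
    # theta <- theta + (1/tau) * (y**2 - theta), element-wise per output neuron.
    return theta + (y ** 2 - theta) / tau

y = np.array([2.0, 0.5])  # constant output activity (assumed for illustration)
theta = np.zeros_like(y)
for _ in range(2000):
    theta = update_threshold_sketch(theta, y)
# theta is now very close to y**2 = [4.0, 0.25]
```

A larger τ averages over a longer history of activity, so θ reacts more slowly to changes in the output.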

References

  • Oja (1982): Simplified neuron model as a principal component analyzer

  • Bienenstock et al. (1982): Theory for the development of neuron selectivity

Initialize the linear layer.

Parameters:
  • input_size (int) – Dimensionality of input vectors

  • output_size (int) – Number of output neurons (features to extract)

  • use_bcm_threshold (bool) – Whether to maintain sliding threshold for BCM learning

  • threshold_tau (float) – Time constant for threshold sliding average (only used if use_bcm_threshold=True)

  • **kwargs – Additional arguments passed to parent class

forward(x)

Forward pass through the layer.

Parameters:

x (jax.numpy.ndarray) – Input vector of shape (input_size,)

Returns:

Output vector of shape (output_size,)

Return type:

jax.numpy.ndarray

resize(input_size, output_size=None, preserve_submatrix=True)

Resize layer dimensions.

Parameters:
  • input_size (int) – New input dimension

  • output_size (int | None) – New output dimension (if None, keep current)

  • preserve_submatrix (bool) – Whether to preserve existing weights

update(prev_energy)

Update method for trainer compatibility (no-op for feedforward layer).

update_threshold()

Update the sliding threshold based on recent activity (BCM only).

This method should be called by BCMTrainer after each forward pass. It updates θ as θ ← θ + (1/τ) * (y² - θ).

W
property energy: float

Energy for trainer compatibility (0 for feedforward layer).

input_size
output_size
property predict_state_attr: str

Name of output state for prediction.

threshold_tau = 100.0
use_bcm_threshold = False
property weight_attr: str

Name of weight parameter for generic training.

x
y
class src.canns.models.brain_inspired.SpikingLayer(input_size, output_size, threshold=1.0, v_reset=0.0, leak=0.9, trace_decay=0.95, dt=1.0, **kwargs)

Bases: src.canns.models.brain_inspired._base.BrainInspiredModel

Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.

This model provides a minimal spiking neuron implementation for demonstrating spike-timing-dependent plasticity (STDP). It features:

  • Leaky integration of input currents

  • Threshold-based spike generation

  • Reset mechanism after spiking

  • Exponential spike traces for STDP learning

Dynamics:

v[t+1] = leak * v[t] + W @ x[t]
spike = 1 if v >= threshold else 0
v = v_reset if spike else v
trace = trace_decay * trace + spike
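One time step of these dynamics can be sketched in NumPy, using the documented defaults for threshold, v_reset, leak and trace_decay. The function lif_step is hypothetical. It tracks only the output trace from the equations; the class also keeps trace_pre and trace_post, whose exact update order is not stated, and the dt parameter is not modelled. With a single input fixed at 1 and weight 0.6, the neuron charges to 0.6 and then 1.14, fires, resets, and repeats.

```python
import numpy as np

def lif_step(v, trace, x, W, threshold=1.0, v_reset=0.0, leak=0.9, trace_decay=0.95):
    v = leak * v + W @ x                    # leaky integration of input current
    spike = (v >= threshold).astype(float)  # 1.0 where the threshold is reached
    v = np.where(spike > 0, v_reset, v)     # reset neurons that fired
    trace = trace_decay * trace + spike     # exponential spike trace
    return v, spike, trace

W = np.array([[0.6, 0.0]])  # 1 output neuron, 2 inputs
x = np.array([1.0, 0.0])    # constant input spike pattern
v = np.zeros(1)
trace = np.zeros(1)
spikes = []
for _ in range(4):
    v, spike, trace = lif_step(v, trace, x, W)
    spikes.append(int(spike[0]))
# spikes == [0, 1, 0, 1]; trace[0] == 0.95**2 + 1 = 1.9025
```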

References

  • Gerstner & Kistler (2002): Spiking Neuron Models

  • Morrison et al. (2008): Phenomenological models of synaptic plasticity

Initialize the spiking layer.

Parameters:
  • input_size (int) – Number of input neurons

  • output_size (int) – Number of output neurons

  • threshold (float) – Spike threshold for membrane potential

  • v_reset (float) – Reset potential after spike

  • leak (float) – Membrane leak factor (0-1, closer to 1 = less leaky)

  • trace_decay (float) – Decay factor for spike traces (used in STDP)

  • dt (float) – Time step size

  • **kwargs – Additional arguments passed to parent class

forward(x)

Forward pass through the spiking layer.

Parameters:

x (jax.numpy.ndarray) – Input spikes of shape (input_size,) with binary values (0 or 1)

Returns:

Output spikes of shape (output_size,) with binary values (0 or 1)

Return type:

jax.numpy.ndarray

reset_state()

Reset membrane potentials and spike traces.

update(prev_energy)

Update method for trainer compatibility (no-op for spiking layer).

W
dt = 1.0
property energy: float

Energy for trainer compatibility (0 for spiking layer).

input_size
leak = 0.9
output_size
property predict_state_attr: str

Name of output state for prediction.

spike
threshold = 1.0
trace_decay = 0.95
trace_post
trace_pre
v
v_reset = 0.0
property weight_attr: str

Name of weight parameter for generic training.

x