src.canns.models.brain_inspired
Brain-inspired neural network models.
This module contains biologically plausible neural network models that incorporate principles from neuroscience and cognitive science, including associative memory, Hebbian learning, and other brain-inspired mechanisms.
Classes
- AmariHopfieldNetwork – Amari-Hopfield Network implementation supporting both discrete and continuous dynamics.
- BrainInspiredModel – Base class for brain-inspired models.
- BrainInspiredModelGroup – Base class for groups of brain-inspired models.
- LinearLayer – Generic linear feedforward layer supporting multiple brain-inspired learning rules.
- SpikingLayer – Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.
Package Contents
- class src.canns.models.brain_inspired.AmariHopfieldNetwork(num_neurons, asyn=False, threshold=0.0, activation='sign', temperature=1.0, **kwargs)
Bases: src.canns.models.brain_inspired._base.BrainInspiredModel
Amari-Hopfield Network implementation supporting both discrete and continuous dynamics.
This class implements Hopfield networks with flexible activation functions, supporting both discrete binary states and continuous dynamics. The network performs pattern completion through energy minimization using asynchronous or synchronous updates.
The network energy function is: E = -0.5 * Σ_ij W_ij * s_i * s_j
where s_i is either discrete (s_i ∈ {-1, +1}) or continuous, depending on the activation function.
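The energy formula above can be evaluated directly. This is a minimal sketch using plain Python lists rather than the library's arrays. The helper name hopfield_energy and the outer-product weight construction are illustrative and not part of the canns API. The helper implements exactly the documented formula, with no threshold term.

```python
def hopfield_energy(W, s):
    """E = -0.5 * sum_ij W_ij * s_i * s_j for a weight matrix W (list of rows) and state s."""
    n = len(s)
    return -0.5 * sum(W[i][j] * s[i] * s[j] for i in range(n) for j in range(n))


# One stored pattern, outer-product weights with zero self-connections (illustrative choice).
p = [1, -1, 1, -1]
W = [[0.0 if i == j else p[i] * p[j] for j in range(4)] for i in range(4)]

print(hopfield_energy(W, p))             # -6.0  (stored pattern sits at low energy)
print(hopfield_energy(W, [1, 1, 1, 1]))  # 2.0   (unrelated state has higher energy)
```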
- References:
Amari, S. (1977). Neural theory of association and concept-formation. Biological Cybernetics, 26(3), 175-185.
Hopfield, J. J. (1982). Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences of the USA, 79(8), 2554-2558.
Initialize the Amari-Hopfield Network.
- Parameters:
num_neurons (int) – Number of neurons in the network
asyn (bool) – If True, update neurons asynchronously (one at a time); if False, update all neurons synchronously
threshold (float) – Threshold for activation function
activation (str) – Activation function type (“sign”, “tanh”, “sigmoid”)
temperature (float) – Temperature parameter for continuous activations
**kwargs – Additional arguments passed to parent class
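The sketch below shows how asyn, threshold, activation and temperature could combine in a single update sweep. It rests on three assumptions: the standard local field h_i = Σ_j W_ij s_j - threshold, temperature dividing the field inside tanh and sigmoid, and sign(0) = +1. These are not confirmed library behavior. update_step and activate are hypothetical helpers, not methods of AmariHopfieldNetwork.

```python
import math
import random


def activate(h, activation="sign", temperature=1.0):
    # Assumed activation forms; the library's exact conventions (e.g. the value of sign(0)) may differ.
    if activation == "sign":
        return 1.0 if h >= 0 else -1.0
    if activation == "tanh":
        return math.tanh(h / temperature)
    if activation == "sigmoid":
        # sigmoid(z) = 0.5 * (1 + tanh(z / 2)); this form avoids overflow in exp()
        return 0.5 * (1.0 + math.tanh(h / (2.0 * temperature)))
    raise ValueError(f"unknown activation: {activation!r}")


def update_step(W, s, asyn=False, threshold=0.0, activation="sign", temperature=1.0, rng=random):
    """One sweep of Hopfield updates (hypothetical helper, not the canns API)."""
    n = len(s)

    def local_field(i, state):
        return sum(W[i][j] * state[j] for j in range(n)) - threshold

    if asyn:
        # Asynchronous: visit neurons one at a time in random order, each seeing the latest state.
        s = list(s)
        for i in rng.sample(range(n), n):
            s[i] = activate(local_field(i, s), activation, temperature)
        return s
    # Synchronous: every neuron is computed from the same previous state.
    return [activate(local_field(i, s), activation, temperature) for i in range(n)]


p = [1, -1, 1, -1]
W = [[0.0 if i == j else p[i] * p[j] for j in range(4)] for i in range(4)]
noisy = [1, 1, 1, -1]  # p with one flipped bit
print(update_step(W, noisy))             # [1.0, -1.0, 1.0, -1.0]
print(update_step(W, noisy, asyn=True))  # [1.0, -1.0, 1.0, -1.0]
```

With one stored pattern, both modes repair the flipped bit in a single sweep. With many stored patterns, asynchronous updates are the variant usually associated with guaranteed energy descent.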
- compute_overlap(pattern1, pattern2)
Compute the overlap between two binary patterns.
- Parameters:
pattern1 – First binary pattern to compare
pattern2 – Second binary pattern to compare
- Returns:
Overlap value (1 for identical, 0 for orthogonal, -1 for opposite)
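For ±1 patterns, the three documented values (1, 0, -1) come out of the normalized dot product. This sketch assumes compute_overlap uses that definition. overlap is a hypothetical stand-alone helper, not the method itself.

```python
def overlap(pattern1, pattern2):
    """Normalized overlap m = (1/N) * sum_i p1_i * p2_i for ±1 patterns (assumed definition)."""
    if len(pattern1) != len(pattern2):
        raise ValueError("patterns must have the same length")
    return sum(a * b for a, b in zip(pattern1, pattern2)) / len(pattern1)


p = [1, -1, 1, -1]
print(overlap(p, p))                # 1.0   identical
print(overlap(p, [1, 1, -1, -1]))   # 0.0   orthogonal
print(overlap(p, [-x for x in p]))  # -1.0  opposite
```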
- resize(num_neurons, preserve_submatrix=True)
Resize the network to num_neurons neurons, resizing the state vector and weight matrix accordingly.
- W
- activation
- asyn = False
- property energy
Energy of the current network state.
- num_neurons
- s
- property storage_capacity
Theoretical storage capacity (the number of patterns the network can store).
- Returns:
Theoretical storage capacity, approximately N/(4*ln(N)) for a network of N neurons
- temperature = 1.0
- threshold = 0.0
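Plugging numbers into the storage_capacity estimate N/(4*ln(N)) shows how slowly it grows with network size. hopfield_capacity is a hypothetical helper that evaluates only that formula.

```python
import math


def hopfield_capacity(num_neurons):
    """Approximate capacity N / (4 ln N); only meaningful for N > 1."""
    return num_neurons / (4 * math.log(num_neurons))


for n in (100, 1000):
    print(n, round(hopfield_capacity(n), 1))  # prints "100 5.4", then "1000 36.2"
```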
- class src.canns.models.brain_inspired.BrainInspiredModel
Bases: src.canns.models.basic._base.BasicModel
Base class for brain-inspired models.
Trainer compatibility notes:
- To support generic Hebbian training, a model should expose a weight parameter attribute whose .value is an array of shape (N, N) (commonly abm.Variable). The recommended attribute name is W; override weight_attr to declare a different attribute name if needed. Models trained with standard backpropagation may omit this entirely.
- Implementing apply_hebbian_learning is optional. Prefer letting the trainer apply the generic rule when applicable, and implement it only when you need model-specific behavior.
Notes on predict compatibility: for the trainer’s generic prediction path, models typically expose:
- an update(prev_energy) method that advances one step (optional; not all models require energy-driven updates),
- an energy property that computes the current energy (scalar-like),
- a state vector attribute (default s) whose .value is a 1D array used as the prediction state; override predict_state_attr to change the name.
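The duck-typed interface above can be illustrated with a toy model and a prediction loop. Everything here is hypothetical. Var stands in for a library state container with .value, and ToyModel and generic_predict are not canns classes. The loop only approximates what the trainer's generic prediction path might do: load the pattern into s, call update(prev_energy), and stop when energy stops changing.

```python
class Var:
    """Minimal stand-in for a state container exposing .value (hypothetical)."""

    def __init__(self, value):
        self.value = value


class ToyModel:
    """Hypothetical model following the conventions above: W, s, energy, update(prev_energy)."""

    def __init__(self, W):
        self.W = Var(W)
        self.s = Var([1.0] * len(W))

    @property
    def energy(self):
        W, s = self.W.value, self.s.value
        n = len(s)
        return -0.5 * sum(W[i][j] * s[i] * s[j] for i in range(n) for j in range(n))

    def update(self, prev_energy):
        # One synchronous sign update; prev_energy is accepted for interface compatibility only.
        W, s = self.W.value, self.s.value
        n = len(s)
        self.s.value = [1.0 if sum(W[i][j] * s[j] for j in range(n)) >= 0 else -1.0 for i in range(n)]


def generic_predict(model, pattern, state_attr="s", max_steps=100):
    """Hypothetical trainer-side loop: load the pattern, then step until the energy stops changing."""
    getattr(model, state_attr).value = list(pattern)
    prev = model.energy
    for _ in range(max_steps):
        model.update(prev)
        cur = model.energy
        if cur == prev:  # converged (or stuck at equal energy); max_steps bounds the loop either way
            break
        prev = cur
    return getattr(model, state_attr).value


p = [1, -1, 1, -1]
W = [[0.0 if i == j else p[i] * p[j] for j in range(4)] for i in range(4)]
print(generic_predict(ToyModel(W), [1, 1, 1, -1]))  # [1.0, -1.0, 1.0, -1.0]
```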
Optional resizing: models may implement resize(num_neurons: int, preserve_submatrix: bool = True) to allow trainers to change neuron dimensionality on the fly (e.g., when training with patterns of a different length). When implemented, the trainer calls it to align dimensions before training or prediction.
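This is one plausible reading of preserve_submatrix: keep the overlapping top-left block of the old weight matrix and zero-fill any new rows and columns. The behavior is an assumption, not taken from the library. resize_weights is a hypothetical helper.

```python
def resize_weights(W, num_neurons, preserve_submatrix=True):
    """Return a num_neurons x num_neurons matrix, optionally copying the overlapping top-left block of W."""
    new_W = [[0.0] * num_neurons for _ in range(num_neurons)]
    if preserve_submatrix:
        k = min(len(W), num_neurons)
        for i in range(k):
            for j in range(k):
                new_W[i][j] = W[i][j]
    return new_W


W = [[0.0, 1.0], [1.0, 0.0]]
print(resize_weights(W, 3))  # [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(resize_weights(W, 1))  # [[0.0]]
```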
- abstractmethod apply_hebbian_learning(train_data)
Optional model-specific Hebbian learning implementation.
The generic HebbianTrainer can update W directly without requiring this method. Implement it only when custom behavior deviates from the generic rule.
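For reference, the classic Hopfield Hebbian rule builds W from outer products of the stored ±1 patterns and zeroes the diagonal. This sketch uses a 1/P normalization, which is one common convention. Whether HebbianTrainer uses this exact normalization is an assumption, and hebbian_weights is a hypothetical helper.

```python
def hebbian_weights(patterns):
    """W_ij = (1/P) * sum_mu p_i^mu * p_j^mu with zero self-connections (one common convention)."""
    n = len(patterns[0])
    num_patterns = len(patterns)
    W = [[0.0] * n for _ in range(n)]
    for p in patterns:
        for i in range(n):
            for j in range(n):
                if i != j:
                    W[i][j] += p[i] * p[j] / num_patterns
    return W


W = hebbian_weights([[1, -1, 1, -1], [1, 1, -1, -1]])
print(W[0])  # [0.0, 0.0, 0.0, -1.0]
```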
- abstractmethod predict(pattern)
- abstractmethod resize(num_neurons, preserve_submatrix=True)
Optional method to resize model state/parameters to num_neurons. The default implementation is a stub; subclasses may override it to support dynamic dimensionality changes.
- abstract property energy: float
Current energy of the model state (used for convergence checks during prediction).
Implementations may return a float or a 0-dim array; the trainer treats it as a scalar.
- class src.canns.models.brain_inspired.BrainInspiredModelGroup
Bases: src.canns.models.basic._base.BasicModelGroup
Base class for groups of brain-inspired models.
This class manages collections of brain-inspired models and provides coordinated learning and dynamics across multiple model instances.
- class src.canns.models.brain_inspired.LinearLayer(input_size, output_size, use_bcm_threshold=False, threshold_tau=100.0, **kwargs)
Bases: src.canns.models.brain_inspired._base.BrainInspiredModel
Generic linear feedforward layer supporting multiple brain-inspired learning rules.
This model provides a simple linear transformation with an optional sliding threshold for BCM-style plasticity. It can be used with various trainers:
- OjaTrainer: normalized Hebbian learning for PCA
- BCMTrainer: sliding-threshold plasticity (requires use_bcm_threshold=True)
- HebbianTrainer: standard Hebbian learning
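For intuition about OjaTrainer, here is the classic single-output Oja rule Δw = η·y·(x - y·w) with y = w·x. It drives w toward the unit-length leading principal direction of the inputs. This is a stand-alone sketch. OjaTrainer's actual update (multi-output form, learning rate, normalization) may differ, and oja_step and the synthetic data are illustrative.

```python
import math
import random


def oja_step(w, x, lr=0.01):
    """One Oja update: y = w . x, then w <- w + lr * y * (x - y * w)."""
    y = sum(wi * xi for wi, xi in zip(w, x))
    return [wi + lr * y * (xi - y * wi) for wi, xi in zip(w, x)]


rng = random.Random(0)
w = [rng.uniform(-0.5, 0.5) for _ in range(2)]
for _ in range(5000):
    # Strongly correlated 2-D inputs: the leading principal direction is (1, 1) / sqrt(2).
    a = rng.gauss(0.0, 1.0)
    x = [a + 0.1 * rng.gauss(0.0, 1.0), a + 0.1 * rng.gauss(0.0, 1.0)]
    w = oja_step(w, x)

# w should end up close to ±(0.71, 0.71) with norm close to 1
print([round(v, 2) for v in w], round(math.hypot(*w), 2))
```

The built-in -y²·w decay term keeps the weight norm bounded, which is what distinguishes Oja's rule from plain Hebbian growth.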
- Computation:
y = W @ x
where W is the weight matrix of shape (output_size, input_size), x is the input vector, and y is the output vector.
- For BCM learning, an optional sliding threshold θ tracks output activity:
θ ← θ + (1/τ) * (y² - θ)
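The forward pass and the sliding-threshold update can be sketched together with a BCM weight change. The sketch assumes y² is applied per output neuron, giving one threshold per output. The weight change uses the classic BCM form ΔW_ij = η·y_i·(y_i - θ_i)·x_j, which is an assumption about what BCMTrainer does. forward, update_threshold and bcm_step are hypothetical stand-alone helpers, not the LinearLayer methods.

```python
def forward(W, x):
    """y = W @ x, with W stored as output_size rows of length input_size."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def update_threshold(theta, y, tau=100.0):
    """Sliding threshold, applied per output neuron: theta <- theta + (1/tau) * (y**2 - theta)."""
    return [t + (yi * yi - t) / tau for t, yi in zip(theta, y)]


def bcm_step(W, x, theta, lr=0.01):
    """Classic BCM weight change dW_ij = lr * y_i * (y_i - theta_i) * x_j (assumed form)."""
    y = forward(W, x)
    W = [[w + lr * yi * (yi - ti) * xj for w, xj in zip(row, x)] for row, yi, ti in zip(W, y, theta)]
    return W, y


W = [[0.5, 0.0], [0.0, 2.0]]
x = [1.0, 1.0]
theta = [0.0, 0.0]
W, y = bcm_step(W, x, theta)      # weights move using the current threshold
theta = update_threshold(theta, y)  # then the threshold tracks recent activity
print(y)      # [0.5, 2.0]
print(theta)  # [0.0025, 0.04]
```

Outputs above θ strengthen their inputs and outputs below it weaken them. Because θ rises with recent y², strongly active neurons become harder to potentiate further.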
References
Oja (1982): Simplified neuron model as a principal component analyzer
Bienenstock et al. (1982): Theory for the development of neuron selectivity
Initialize the linear layer.
- Parameters:
input_size (int) – Dimensionality of input vectors
output_size (int) – Number of output neurons (features to extract)
use_bcm_threshold (bool) – Whether to maintain sliding threshold for BCM learning
threshold_tau (float) – Time constant for threshold sliding average (only used if use_bcm_threshold=True)
**kwargs – Additional arguments passed to parent class
- forward(x)
Forward pass through the layer.
- Parameters:
x (jax.numpy.ndarray) – Input vector of shape (input_size,)
- Returns:
Output vector of shape (output_size,)
- Return type:
jax.numpy.ndarray
- update_threshold()
Update the sliding threshold based on recent activity (BCM only).
This method should be called by BCMTrainer after each forward pass. Updates θ using: θ ← θ + (1/τ) * (y² - θ)
- W
- input_size
- output_size
- threshold_tau = 100.0
- use_bcm_threshold = False
- x
- y
- class src.canns.models.brain_inspired.SpikingLayer(input_size, output_size, threshold=1.0, v_reset=0.0, leak=0.9, trace_decay=0.95, dt=1.0, **kwargs)
Bases: src.canns.models.brain_inspired._base.BrainInspiredModel
Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.
This model provides a minimal spiking neuron implementation for demonstrating spike-timing-dependent plasticity (STDP). It features:
- leaky integration of input currents
- threshold-based spike generation
- a reset mechanism after spiking
- exponential spike traces for STDP learning
- Dynamics:
v[t+1] = leak * v[t] + W @ x[t]
spike = 1 if v >= threshold else 0
v = v_reset if spike else v
trace = trace_decay * trace + spike
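The four dynamics lines translate directly into one simulation step. This sketch uses lists in place of JAX arrays and follows the documented equations only. In particular, dt does not appear in them and is not used here. lif_step is a hypothetical helper, not SpikingLayer.forward.

```python
def lif_step(v, trace, W, x, threshold=1.0, v_reset=0.0, leak=0.9, trace_decay=0.95):
    """One step of the documented LIF dynamics for one layer (lists instead of arrays)."""
    # Leaky integration: v <- leak * v + W @ x
    v = [leak * vi + sum(w * xj for w, xj in zip(row, x)) for vi, row in zip(v, W)]
    # Spike wherever the membrane potential reaches the threshold
    spike = [1 if vi >= threshold else 0 for vi in v]
    # Reset the neurons that spiked
    v = [v_reset if s else vi for vi, s in zip(v, spike)]
    # Exponential trace: decays every step and jumps by 1 on a spike
    trace = [trace_decay * t + s for t, s in zip(trace, spike)]
    return v, spike, trace


# One output neuron driven by a constant input spike through weight 0.4:
# v goes 0.4 -> 0.76 -> 1.084 (spike, reset) and the cycle repeats.
v, trace, spikes = [0.0], [0.0], []
for _ in range(6):
    v, spike, trace = lif_step(v, trace, [[0.4]], [1])
    spikes.append(spike[0])
print(spikes)  # [0, 0, 1, 0, 0, 1]
```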
References
Gerstner & Kistler (2002): Spiking Neuron Models
Morrison et al. (2008): Phenomenological models of synaptic plasticity
Initialize the spiking layer.
- Parameters:
input_size (int) – Number of input neurons
output_size (int) – Number of output neurons
threshold (float) – Spike threshold for membrane potential
v_reset (float) – Reset potential after spike
leak (float) – Membrane leak factor, between 0 and 1 (values closer to 1 mean less leak)
trace_decay (float) – Decay factor for spike traces (used in STDP)
dt (float) – Time step size
**kwargs – Additional arguments passed to parent class
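The trace_pre and trace_post attributes support an online, pair-based STDP rule. A post spike reads the presynaptic trace and potentiates the synapse; a pre spike reads the postsynaptic trace and depresses it. This is a generic sketch of that standard trace-based rule. The learning rates a_plus and a_minus and the helper stdp_update are hypothetical, and the library's STDP trainer may use a different form.

```python
def stdp_update(W, pre_spike, post_spike, trace_pre, trace_post, a_plus=0.01, a_minus=0.012):
    """Trace-based pair STDP for W of shape (output_size, input_size).

    Potentiation: a post spike reads the presynaptic trace (pre fired recently -> LTP).
    Depression:   a pre spike reads the postsynaptic trace (post fired recently -> LTD).
    """
    return [
        [
            w + a_plus * post_spike[i] * trace_pre[j] - a_minus * pre_spike[j] * trace_post[i]
            for j, w in enumerate(row)
        ]
        for i, row in enumerate(W)
    ]


# Pre fired one step ago (trace 0.95 after one decay with trace_decay=0.95); post fires now: weight increases.
ltp = stdp_update([[0.5]], pre_spike=[0], post_spike=[1], trace_pre=[0.95], trace_post=[0.0])
# Post fired one step ago; pre fires now: weight decreases.
ltd = stdp_update([[0.5]], pre_spike=[1], post_spike=[0], trace_pre=[0.0], trace_post=[0.95])
print(ltp, ltd)
```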
- forward(x)
Forward pass through the spiking layer.
- Parameters:
x (jax.numpy.ndarray) – Input spikes of shape (input_size,) with binary values (0 or 1)
- Returns:
Output spikes of shape (output_size,) with binary values (0 or 1)
- Return type:
jax.numpy.ndarray
- W
- dt = 1.0
- input_size
- leak = 0.9
- output_size
- spike
- threshold = 1.0
- trace_decay = 0.95
- trace_post
- trace_pre
- v
- v_reset = 0.0
- x