Source code for src.canns.models.basic.hierarchical_model

import brainpy as bp
import brainpy.math as bm
import jax
import jax.numpy as jnp

from ...task.open_loop_navigation import map2pi
from ._base import BasicModel, BasicModelGroup

__all__ = [
    # Base Units
    "GaussRecUnits",
    "NonRecUnits",
    # Band Cell and Grid Cell Models
    "BandCell",
    "GridCell",
    # Hierarchical Path Integration Model
    "HierarchicalPathIntegrationModel",
    # Hierarchical Network
    "HierarchicalNetwork",
]


[docs] class GaussRecUnits(BasicModel): """A model of recurrently connected units with Gaussian connectivity. This class implements a 1D continuous attractor neural network (CANN). The network maintains a stable "bump" of activity that can represent a continuous variable, such as heading direction. The connectivity between neurons is Gaussian, and the network dynamics include divisive normalization. Attributes: size (int): The number of neurons in the network. tau (float): The time constant for the synaptic input `u`. k (float): The inhibition strength for divisive normalization. a (float): The width of the Gaussian connection profile. noise_0 (float): The standard deviation of the Gaussian noise added to the system. z_min (float): The minimum value of the encoded feature space. z_max (float): The maximum value of the encoded feature space. z_range (float): The range of the feature space (z_max - z_min). x (bm.math.ndarray): The preferred feature values for each neuron. rho (float): The neural density (number of neurons per unit of feature space). dx (float): The stimulus density (feature space range per neuron). J (float): The final connection strength, scaled by J0. conn_mat (bm.math.ndarray): The connection matrix. r (bm.Variable): The firing rates of the neurons. u (bm.Variable): The synaptic inputs to the neurons. center (bm.Variable): The decoded center of the activity bump. input (bm.Variable): The external input to the network. """ def __init__( self, size: int, tau: float = 1.0, J0: float = 1.1, k: float = 5e-4, a: float = 2 / 9 * bm.pi, z_min: float = -bm.pi, z_max: float = bm.pi, noise: float = 2.0, ): """Initializes the GaussRecUnits model. Args: size (int): The number of neurons in the network. tau (float, optional): The time constant of the neurons. Defaults to 1.0. J0 (float, optional): A scaling factor for the critical connection strength. Defaults to 1.1. k (float, optional): The strength of the global inhibition. Defaults to 5e-4. a (float, optional): The width of the Gaussian connection profile. Defaults to 2/9*pi. z_min (float, optional): The minimum value of the feature space. Defaults to -pi. z_max (float, optional): The maximum value of the feature space. Defaults to pi. noise (float, optional): The level of noise in the system. Defaults to 2.0. """
[docs] self.size = size
super().__init__()
[docs] self.tau = tau # The time constant
[docs] self.k = k # The inhibition strength
[docs] self.a = a # The width of the Gaussian connection
[docs] self.noise_0 = noise # The noise level
# feature space
[docs] self.z_min = z_min
[docs] self.z_max = z_max
[docs] self.z_range = z_max - z_min
[docs] self.x = bm.linspace(z_min, z_max, size, endpoint=False) # The encoded feature values
[docs] self.rho = size / self.z_range # The neural density
[docs] self.dx = self.z_range / size # The stimulus density
[docs] self.J = J0 * self.Jc() # The connection strength
[docs] self.conn_mat = self.make_conn() # The connection matrix
[docs] self.r = bm.Variable(bm.zeros(self.size)) # The neural firing rate
[docs] self.u = bm.Variable(bm.zeros(self.size)) # The neural synaptic input
[docs] self.center = bm.Variable( bm.zeros( 1, ) ) # The center of the bump
[docs] self.input = bm.Variable(bm.zeros(self.size)) # The external input
# initialize the neural activity self.u.value = ( 10.0 * bm.exp(-0.5 * bm.square((self.x - 0) / self.a)) / (bm.sqrt(2 * bm.pi) * self.a) ) self.r.value = ( 30.0 * bm.exp(-0.5 * bm.square((self.x - 0) / self.a)) / (bm.sqrt(2 * bm.pi) * self.a) ) # make the connection matrix
[docs] def make_conn(self): """Constructs the periodic Gaussian connection matrix. The connection strength between two neurons depends on the periodic distance between their preferred feature values, following a Gaussian profile. """ dis = self.x[:, None] - self.x[None, :] d = self.dist(dis) return self.J * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
# critical connection strength
[docs] def Jc(self): """Calculates the critical connection strength. This is the minimum connection strength required to sustain a stable activity bump in the attractor network. """ return bm.sqrt(8 * bm.sqrt(2 * bm.pi) * self.k * self.a / self.rho)
# truncate the distance into the range of feature space
[docs] def dist(self, d): """Calculates the periodic distance in the feature space. This function wraps distances to ensure they fall within the periodic boundaries of the feature space, i.e., [-z_range/2, z_range/2]. Args: d (bm.math.ndarray): The array of distances. """ d = bm.remainder(d, self.z_range) d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d) return d
# decode the neural activity
[docs] def decode(self, r, axis=0): """Decodes the center of the activity bump. This method uses a population vector average to compute the center of the neural activity bump from the firing rates. Args: r (Array): The firing rates of the neurons. axis (int, optional): The axis along which to perform the decoding. Defaults to 0. Returns: float: The angle representing the decoded center of the bump. """ expo_r = bm.exp(1j * self.x) * r return bm.angle(bm.sum(expo_r, axis=axis) / bm.sum(r, axis=axis))
# update the neural activity
[docs] def update(self, input): self.input.value = input r1 = bm.square(self.u.value) r2 = 1.0 + self.k * bm.sum(r1) self.r.value = r1 / r2 Irec = bm.dot(self.conn_mat, self.r.value) self.u.value = ( self.u.value + (-self.u.value + Irec + self.input.value) / self.tau * bm.get_dt() ) self.input.value = self.input.value.at[:].set(0.0) self.center.value = self.center.value.at[0].set(self.decode(self.u.value))
[docs] class NonRecUnits(BasicModel): """A model of non-recurrently connected units. This class implements a simple leaky integrator model for a population of neurons that do not have recurrent connections among themselves. They respond to external inputs and have a non-linear activation function. Attributes: size (int): The number of neurons. noise_0 (float): The standard deviation of the Gaussian noise. tau (float): The time constant for the synaptic input `u`. z_min (float): The minimum value of the encoded feature space. z_max (float): The maximum value of the encoded feature space. z_range (float): The range of the feature space. x (bm.ndarray): The preferred feature values for each neuron. rho (float): The neural density. dx (float): The stimulus density. r (bm.Variable): The firing rates of the neurons. u (bm.Variable): The synaptic inputs to the neurons. input (bm.Variable): The external input to the neurons. """ def __init__( self, size: int, tau: float = 0.1, z_min: float = -bm.pi, z_max: float = bm.pi, noise: float = 2.0, ): """Initializes the NonRecUnits model. Args: size (int): The number of neurons. tau (float, optional): The time constant of the neurons. Defaults to 0.1. z_min (float, optional): The minimum value of the feature space. Defaults to -pi. z_max (float, optional): The maximum value of the feature space. Defaults to pi. noise (float, optional): The level of noise in the system. Defaults to 2.0. """ super().__init__()
[docs] self.size = size
[docs] self.noise_0 = noise # The noise level
[docs] self.tau = tau # The time constant
# feature space
[docs] self.z_min = z_min
[docs] self.z_max = z_max
[docs] self.z_range = z_max - z_min
[docs] self.x = bm.linspace(z_min, z_max, size, endpoint=False) # The encoded feature values
[docs] self.rho = size / self.z_range # The neural density
[docs] self.dx = self.z_range / size # The stimulus density
[docs] self.r = bm.Variable(bm.zeros(self.size)) # The neural firing rate
[docs] self.u = bm.Variable(bm.zeros(self.size)) # The neural synaptic input
[docs] self.input = bm.Variable(bm.zeros(self.size)) # The external input
# choose the activation function
[docs] def activate(self, x): """Applies an activation function to the input. Args: x (Array): The input to the activation function (e.g., synaptic input `u`). Returns: Array: The result of the activation function (ReLU). """ return bm.relu(x)
[docs] def dist(self, d): """Calculates the periodic distance in the feature space. This function wraps distances to ensure they fall within the periodic boundaries of the feature space. Args: d (Array): The array of distances. Returns: Array: The wrapped distances. """ d = bm.remainder(d, self.z_range) d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d) return d
[docs] def update(self, input): self.input.value = input self.r.value = bm.where( self.noise_0 != 0.0, self.activate(self.u.value) + self.noise_0 * bm.random.randn(self.size), self.activate(self.u.value), ) # self.r.value = self.activate(self.u.value) + self.noise_0 * bm.random.randn( # self.size # ) self.u.value = self.u.value + (-self.u.value + self.input.value) / self.tau * bm.get_dt() self.input.value = self.input.value.at[:].set(0.0) return self.r.value
# the intact networks contains a group of EPG neurons (recurrent units), two P-EN neurons (non-recurrent units), one group of # FC2 (recurrent units), two PFL3 (non-recurrent units) and two DN neurons (non-recurrent units)
[docs] class BandCell(BasicModel): """A model of a band cell module for path integration. This model represents a set of neurons whose receptive fields form parallel bands across a 2D space. It is composed of a central `GaussRecUnits` attractor network (the band cells proper) that represents a 1D phase, and two `NonRecUnits` populations (left and right) that help shift the activity in the attractor network based on velocity input. This mechanism allows the module to integrate the component of velocity along its preferred direction. Attributes: size (int): The number of neurons in each sub-population. spacing (float): The spacing between the bands in the 2D environment. angle (float): The orientation angle of the bands. proj_k (bm.math.ndarray): The projection vector for converting 2D position/velocity to 1D phase. band_cells (GaussRecUnits): The core recurrent network representing the phase. left (NonRecUnits): A population of non-recurrent units for positive shifts. right (NonRecUnits): A population of non-recurrent units for negative shifts. w_L2S (float): Connection weight from band cells to left/right units. w_S2L (float): Connection weight from left/right units to band cells. gain (float): A gain factor for velocity-modulated input. center_ideal (bm.Variable): The ideal, noise-free center based on velocity integration. center (bm.Variable): The actual decoded center of the band cell activity bump. """ def __init__( self, angle, spacing, size=180, z_min=-bm.pi, z_max=bm.pi, noise=2.0, w_L2S=0.2, w_S2L=1.0, gain=0.2, # GaussRecUnits configuration gauss_tau=1.0, gauss_J0=1.1, gauss_k=5e-4, gauss_a=2 / 9 * bm.pi, # NonRecUnits configuration nonrec_tau=0.1, **kwargs, ): """Initializes the BandCell model. Args: angle (float): The orientation angle of the bands. spacing (float): The spacing between the bands. size (int, optional): The number of neurons in each group. Defaults to 180. z_min (float, optional): The minimum value of the feature space (phase). Defaults to -pi. z_max (float, optional): The maximum value of the feature space (phase). Defaults to pi. noise (float, optional): The noise level for the neuron groups. Defaults to 2.0. w_L2S (float, optional): Weight from band cells to shifter units. Defaults to 0.2. w_S2L (float, optional): Weight from shifter units to band cells. Defaults to 1.0. gain (float, optional): A gain factor for the velocity signal. Defaults to 0.2. gauss_tau (float, optional): Time constant for GaussRecUnits. Defaults to 1.0. gauss_J0 (float, optional): Connection strength scaling factor for GaussRecUnits. Defaults to 1.1. gauss_k (float, optional): Global inhibition strength for GaussRecUnits. Defaults to 5e-4. gauss_a (float, optional): Gaussian connection width for GaussRecUnits. Defaults to 2/9*pi. nonrec_tau (float, optional): Time constant for NonRecUnits. Defaults to 0.1. **kwargs: Additional keyword arguments for the base class. """
[docs] self.size = size # The number of neurons in each neuron group except DN
super().__init__(**kwargs) # feature space
[docs] self.z_min = z_min
[docs] self.z_max = z_max
[docs] self.z_range = z_max - z_min
[docs] self.x = bm.linspace(z_min, z_max, size, endpoint=False) # The encoded feature values
[docs] self.rho = size / self.z_range # The neural density
[docs] self.dx = self.z_range / size # The stimulus density
[docs] self.spacing = spacing
[docs] self.angle = angle
[docs] self.proj_k = ( bm.array([bm.cos(angle - bm.pi / 2), bm.sin(angle - bm.pi / 2)]) * 2 * bm.pi / spacing )
# shifts
[docs] self.phase_shift = 1 / 9 * bm.pi * 0.76 # the shift of the connection from PEN to EPG
# self.PFL3_shift = 3/8*bm.pi # the shift of the connection from EPG to PFL3 # self.PEN_shift_num = int(self.PEN_shift / self.dx) # the number of interval shifted # self.PFL3_shift_num = int(self.PFL3_shift / self.dx) # the number of interval shifted # neurons - create with custom parameters
[docs] self.band_cells = GaussRecUnits( size=size, tau=gauss_tau, J0=gauss_J0, k=gauss_k, a=gauss_a, z_min=z_min, z_max=z_max, noise=noise, ) # heading direction
[docs] self.left = NonRecUnits(size=size, tau=nonrec_tau, z_min=z_min, z_max=z_max, noise=noise)
[docs] self.right = NonRecUnits(size=size, tau=nonrec_tau, z_min=z_min, z_max=z_max, noise=noise)
# weights
[docs] self.w_L2S = w_L2S
[docs] self.w_S2L = w_S2L
[docs] self.gain = gain
self.synapses()
[docs] self.center_ideal = bm.Variable( bm.zeros( 1, ) ) # The center of v-
[docs] self.center = bm.Variable( bm.zeros( 1, ) ) # The center of v-
# define the synapses
[docs] def synapses(self): """Defines the synaptic connections between the neuron groups. This method sets up the shifted connections from the left/right shifter populations to the central band cell attractor network, as well as the one-to-one connections from the band cells to the shifters. """ self.W_PENl2EPG = self.w_S2L * self.make_conn(self.phase_shift) self.W_PENr2EPG = self.w_S2L * self.make_conn(-self.phase_shift) # synapses self.syn_Band2Left = bp.dnn.OneToOne(self.size, self.w_L2S) self.syn_Band2Right = bp.dnn.OneToOne(self.size, self.w_L2S) self.syn_Left2Band = bp.dnn.Linear(self.size, self.size, self.W_PENl2EPG) self.syn_Right2Band = bp.dnn.Linear(self.size, self.size, self.W_PENr2EPG)
[docs] def dist(self, d): """Calculates the periodic distance in the feature space. Args: d (Array): The array of distances. Returns: Array: The wrapped distances. """ d = bm.remainder(d, self.z_range) d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d) return d
[docs] def make_conn(self, shift): """Creates a shifted Gaussian connection profile. This is used to create the connections from the left/right shifter units to the band cells, which implements the bump-shifting mechanism. Args: shift (float): The amount to shift the connection profile by. Returns: Array: The shifted connection matrix. """ d = self.dist(self.x[:, None] - self.x[None, :] + shift) return bm.exp(-0.5 * bm.square(d / self.band_cells.a)) / ( bm.sqrt(2 * bm.pi) * self.band_cells.a )
[docs] def Postophase(self, pos): """Projects a 2D position to a 1D phase. This function converts a 2D coordinate in the environment into a 1D phase value based on the band cell's preferred angle and spacing. Args: pos (Array): The 2D position vector. Returns: float: The corresponding 1D phase. """ phase = bm.mod(bm.dot(pos, self.proj_k), 2 * bm.pi) - bm.pi return phase
[docs] def get_stimulus_by_pos(self, pos): """Generates a stimulus input based on a 2D position. This creates a Gaussian bump of input centered on the phase corresponding to the given position, which can be used to anchor the network's activity. Args: pos (Array): The 2D position vector. Returns: Array: The stimulus input vector for the band cells. """ phase = self.Postophase(pos) d = self.dist(phase - self.x) return bm.exp(-0.25 * bm.square(d / self.band_cells.a))
# move the heading direction representation (for testing)
[docs] def move_heading(self, shift): """Manually shifts the activity bump in the band cells. This is a utility function for testing purposes. Args: shift (int): The number of neurons to roll the activity by. """ self.band_cells.r.value = bm.roll(self.band_cells.r, shift) self.band_cells.u.value = bm.roll(self.band_cells.u, shift)
[docs] def get_center(self): """Decodes and updates the current center of the band cell activity.""" exppos = bm.exp(1j * self.x) r = self.band_cells.r.value self.center.value = bm.angle(bm.atleast_1d(bm.sum(exppos * r)))
[docs] def reset(self): """Resets the synaptic inputs of the left and right shifter units.""" self.left.u.value = bm.zeros(self.size) self.right.u.value = bm.zeros(self.size)
[docs] def update(self, velocity, loc, loc_input_stre): """Updates the BandCell module for one time step. It integrates the component of `velocity` along the module's preferred direction to update the phase representation. The activity bump is shifted by modulating the inputs from the left/right shifter populations. It can also incorporate a direct location-based input. Args: velocity (Array): The 2D velocity vector. loc (Array): The current 2D location. loc_input_stre (float): The strength of the location-based input. """ loc_input = jax.lax.cond( loc_input_stre != 0.0, lambda op: self.get_stimulus_by_pos(op[0]) * op[1], lambda op: bm.zeros(self.size, dtype=float).value, operand=(loc, loc_input_stre), ) # if loc_input_stre != 0.: # loc_input = self.get_stimulus_by_pos(loc) * loc_input_stre # else: # loc_input = bm.zeros(self.size) v_phi = bm.dot(velocity, self.proj_k) center_ideal = self.center_ideal.value + v_phi * bm.get_dt() self.center_ideal.value = map2pi(center_ideal) # EPG output last time step Band_output = self.band_cells.r.value # PEN input left_input = self.syn_Band2Left(Band_output) right_input = self.syn_Band2Right(Band_output) # PEN output and gain self.left(left_input) self.right(right_input) self.left.r.value = (self.gain + v_phi) * self.left.r.value self.right.r.value = (self.gain - v_phi) * self.right.r.value # EPG input Band_input = self.syn_Left2Band(self.left.r.value) + self.syn_Right2Band(self.right.r.value) # EPG output self.band_cells(Band_input + loc_input) # self.Band_cells.update(loc_input) self.get_center()
# Grid cell model modules
[docs] class GridCell(BasicModel): """A model of a grid cell module using a 2D continuous attractor network. This class implements a 2D continuous attractor network on a toroidal manifold to model the firing patterns of grid cells. The network dynamics include synaptic depression or adaptation, which helps stabilize the activity bumps. The connectivity is defined on a hexagonal grid structure. Attributes: num (int): The total number of neurons (num_side x num_side). tau (float): The synaptic time constant for `u`. tau_v (float): The time constant for the adaptation variable `v`. k (float): The degree of rescaled inhibition. a (float): The half-width of the excitatory connection range. A (float): The magnitude of the external input. J0 (float): The maximum connection value. m (float): The strength of the adaptation. angle (float): The orientation of the grid. value_grid (bm.math.ndarray): The (x, y) preferred phase coordinates for each neuron. conn_mat (bm.math.ndarray): The connection matrix. r (bm.Variable): The firing rates of the neurons. u (bm.Variable): The synaptic inputs to the neurons. v (bm.Variable): The adaptation variables for the neurons. center (bm.Variable): The decoded 2D center of the activity bump. """ def __init__( self, num, angle, spacing, tau=0.1, tau_v=10.0, k=5e-3, a=bm.pi / 9, A=1.0, J0=1.0, mbar=1.0, ): """Initializes the GridCell model. Args: num (int): The number of neurons along one dimension of the square grid. angle (float): The orientation angle of the grid pattern. spacing (float): The spacing of the grid pattern. tau (float, optional): The synaptic time constant. Defaults to 0.1. tau_v (float, optional): The adaptation time constant. Defaults to 10.0. k (float, optional): The strength of global inhibition. Defaults to 5e-3. a (float, optional): The width of the connection profile. Defaults to pi/9. A (float, optional): The magnitude of external input. Defaults to 1.0. J0 (float, optional): The maximum connection strength. Defaults to 1.0. mbar (float, optional): The base strength of adaptation. Defaults to 1.0. """
[docs] self.num = num**2
super().__init__() # dynamics parameters
[docs] self.tau = tau # The synaptic time constant
[docs] self.tau_v = tau_v
# self.w_max = w_max
[docs] self.ratio = bm.pi * 2 / spacing
[docs] self.k = k # Degree of the rescaled inhibition
[docs] self.a = a # Half-width of the range of excitatory connections
[docs] self.A = A # Magnitude of the external input
[docs] self.J0 = J0 # maximum connection value
[docs] self.m = mbar * tau / tau_v
[docs] self.angle = angle
# feature space
[docs] self.x_range = 2 * bm.pi
[docs] self.x = bm.linspace(-bm.pi, bm.pi, num, endpoint=False)
x_grid, y_grid = bm.meshgrid(self.x, self.x)
[docs] self.x_grid = x_grid.flatten()
[docs] self.y_grid = y_grid.flatten()
[docs] self.value_grid = bm.stack([self.x_grid, self.y_grid]).T
[docs] self.rho = self.num / (self.x_range**2) # The neural density
[docs] self.dxy = 1 / self.rho # The stimulus density
[docs] self.coor_transform = bm.array([[1, -1 / bm.sqrt(3)], [0, 2 / bm.sqrt(3)]])
[docs] self.rot = bm.array( [ [bm.cos(self.angle), -bm.sin(self.angle)], [bm.sin(self.angle), bm.cos(self.angle)], ] )
# initialize conn matrix
[docs] self.conn_mat = self.make_conn()
[docs] self.r = bm.Variable(bm.zeros(self.num))
[docs] self.u = bm.Variable(bm.zeros(self.num))
[docs] self.v = bm.Variable(bm.zeros(self.num))
[docs] self.input = bm.Variable(bm.zeros(self.num))
[docs] self.center = bm.Variable( bm.zeros( 2, ) )
[docs] self.integral = bp.odeint(method="exp_euler", f=self.derivative)
@property
[docs] def derivative(self): du = lambda u, t, Irec: (-u + Irec + self.input - self.v) / self.tau dv = lambda v, t: (-v + self.m * self.u) / self.tau_v return bp.JointEq([du, dv])
[docs] def reset_state(self, *args, **kwargs): """Resets the state variables of the model to zeros.""" self.r.value = bm.zeros(self.num) self.u.value = bm.zeros(self.num) self.v.value = bm.zeros(self.num) self.input.value = bm.zeros(self.num) self.center.value = bm.zeros( 2, )
[docs] def dist(self, d): """Calculates the distance on the hexagonal grid. It first maps the periodic difference vector `d` into a Cartesian coordinate system that reflects the hexagonal lattice structure and then computes the Euclidean distance. Args: d (Array): An array of difference vectors in the phase space. Returns: Array: The corresponding distances on the hexagonal lattice. """ d = map2pi(d) delta_x = d[:, 0] delta_y = (d[:, 1] - 1 / 2 * d[:, 0]) * 2 / bm.sqrt(3) return bm.sqrt(delta_x**2 + delta_y**2)
[docs] def make_conn(self): """Constructs the connection matrix for the 2D attractor network. The connection strength between two neurons is a Gaussian function of the hexagonal distance between their preferred phases. Returns: Array: The connection matrix (num x num). """ @jax.vmap def get_J(v): d = self.dist(v - self.value_grid) Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a) return Jxx return get_J(self.value_grid)
[docs] def circle_period(self, d): """Wraps values into the periodic range [-pi, pi]. Args: d (Array): The input values. Returns: Array: The wrapped values. """ d = bm.where(d > bm.pi, d - 2 * bm.pi, d) d = bm.where(d < -bm.pi, d + 2 * bm.pi, d) return d
[docs] def get_center(self): """Decodes and updates the 2D center of the activity bump. It uses a population vector average for both the x and y dimensions of the phase space. """ exppos_x = bm.exp(1j * self.x_grid) exppos_y = bm.exp(1j * self.y_grid) r = bm.where(self.r.value > bm.max(self.r.value) * 0.1, self.r.value, 0) self.center.value = bm.asarray( [bm.angle(bm.sum(exppos_x * r)), bm.angle(bm.sum(exppos_y * r))] )
[docs] def update(self, input): self.input.value = input Irec = bm.dot(self.conn_mat, self.r.value) # Update neural state _u, _v = self.integral(self.u, self.v, bp.share["t"], Irec, bm.dt) self.u.value = bm.where(_u > 0, _u, 0) self.v.value = _v # self.u.value += ( # (-self.u.value + Irec + self.input.value - self.v.value) # / self.tau # * bm.get_dt() # ) # self.u.value = bm.where(self.u.value > 0, self.u.value, 0) # self.v.value += ( # (-self.v.value + self.m * self.u.value) / self.tau_v * bm.get_dt() # ) r1 = bm.square(self.u.value) r2 = 1.0 + self.k * bm.sum(r1) self.r.value = r1 / r2 self.get_center()
[docs] class HierarchicalPathIntegrationModel(BasicModelGroup): """A hierarchical model combining band cells and grid cells for path integration. This model forms a single grid module. It consists of three `BandCell` modules, each with a different preferred orientation (separated by 60 degrees), and one `GridCell` module. The band cells integrate velocity along their respective directions, and their combined outputs provide the input to the `GridCell` network, effectively driving the grid cell's activity bump. The model can also project its grid cell activity to a population of place cells. Attributes: band_cell_x (BandCell): The first band cell module (orientation `angle`). band_cell_y (BandCell): The second band cell module (orientation `angle` + 60 deg). band_cell_z (BandCell): The third band cell module (orientation `angle` + 120 deg). grid_cell (GridCell): The grid cell module driven by the band cells. place_center (bm.math.ndarray): The center locations of the target place cells. Wg2p (bm.math.ndarray): The connection weights from grid cells to place cells. grid_output (bm.Variable): The activity of the place cells. """ def __init__( self, spacing, angle, place_center=None, # BandCell configuration band_size=180, band_noise=0.0, band_w_L2S=0.2, band_w_S2L=1.0, band_gain=0.2, # GridCell configuration grid_num=20, grid_tau=0.1, grid_tau_v=10.0, grid_k=5e-3, grid_a=bm.pi / 9, grid_A=1.0, grid_J0=1.0, grid_mbar=1.0, # GaussRecUnits configuration (for BandCell) gauss_tau=1.0, gauss_J0=1.1, gauss_k=5e-4, gauss_a=2 / 9 * bm.pi, # NonRecUnits configuration (for BandCell) nonrec_tau=0.1, ): """Initializes the HierarchicalPathIntegrationModel. Args: spacing (float): The spacing of the grid pattern for this module. angle (float): The base orientation angle for the module. place_center (bm.math.ndarray, optional): The center locations of the target place cell population. Defaults to a random distribution. band_size (int, optional): Number of neurons in each BandCell group. Defaults to 180. band_noise (float, optional): Noise level for BandCells. Defaults to 0.0. band_w_L2S (float, optional): Weight from band cells to shifter units. Defaults to 0.2. band_w_S2L (float, optional): Weight from shifter units to band cells. Defaults to 1.0. band_gain (float, optional): Gain factor for velocity signal in BandCells. Defaults to 0.2. grid_num (int, optional): Number of neurons per dimension for GridCell. Defaults to 20. grid_tau (float, optional): Synaptic time constant for GridCell. Defaults to 0.1. grid_tau_v (float, optional): Adaptation time constant for GridCell. Defaults to 10.0. grid_k (float, optional): Global inhibition strength for GridCell. Defaults to 5e-3. grid_a (float, optional): Connection width for GridCell. Defaults to pi/9. grid_A (float, optional): External input magnitude for GridCell. Defaults to 1.0. grid_J0 (float, optional): Maximum connection strength for GridCell. Defaults to 1.0. grid_mbar (float, optional): Base adaptation strength for GridCell. Defaults to 1.0. gauss_tau (float, optional): Time constant for GaussRecUnits in BandCells. Defaults to 1.0. gauss_J0 (float, optional): Connection strength scaling for GaussRecUnits. Defaults to 1.1. gauss_k (float, optional): Global inhibition for GaussRecUnits. Defaults to 5e-4. gauss_a (float, optional): Connection width for GaussRecUnits. Defaults to 2/9*pi. nonrec_tau (float, optional): Time constant for NonRecUnits in BandCells. Defaults to 0.1. """ super().__init__() # Create BandCell instances with custom parameters
[docs] self.band_cell_x = BandCell( angle=angle, spacing=spacing, size=band_size, noise=band_noise, w_L2S=band_w_L2S, w_S2L=band_w_S2L, gain=band_gain, gauss_tau=gauss_tau, gauss_J0=gauss_J0, gauss_k=gauss_k, gauss_a=gauss_a, nonrec_tau=nonrec_tau, )
[docs] self.band_cell_y = BandCell( angle=angle + bm.pi / 3, spacing=spacing, size=band_size, noise=band_noise, w_L2S=band_w_L2S, w_S2L=band_w_S2L, gain=band_gain, gauss_tau=gauss_tau, gauss_J0=gauss_J0, gauss_k=gauss_k, gauss_a=gauss_a, nonrec_tau=nonrec_tau, )
[docs] self.band_cell_z = BandCell( angle=angle + bm.pi / 3 * 2, spacing=spacing, size=band_size, noise=band_noise, w_L2S=band_w_L2S, w_S2L=band_w_S2L, gain=band_gain, gauss_tau=gauss_tau, gauss_J0=gauss_J0, gauss_k=gauss_k, gauss_a=gauss_a, nonrec_tau=nonrec_tau, )
# Create GridCell instance with custom parameters
[docs] self.grid_cell = GridCell( num=grid_num, angle=angle, spacing=spacing, tau=grid_tau, tau_v=grid_tau_v, k=grid_k, a=grid_a, A=grid_A, J0=grid_J0, mbar=grid_mbar, )
[docs] self.proj_k_x = self.band_cell_x.proj_k
[docs] self.proj_k_y = self.band_cell_y.proj_k
[docs] self.place_center = ( place_center if place_center is not None else 10 * bm.random.rand(512, 2) )
self.make_conn() self.make_Wg2p()
[docs] self.num_place = place_center.shape[0]
[docs] self.coor_transform = bm.array([[1, -1 / bm.sqrt(3)], [0, 2 / bm.sqrt(3)]])
[docs] self.grid_output = bm.Variable(bm.zeros(self.num_place))
[docs] def make_conn(self): """Creates the connection matrices from the band cells to the grid cells. The connection from a band cell to a grid cell is strong if the grid cell's preferred phase along the band cell's direction matches the band cell's preferred phase. """ value_grid = self.grid_cell.value_grid band_x = self.band_cell_x.x band_y = self.band_cell_y.x band_z = self.band_cell_z.x J0 = self.grid_cell.J0 * 0.1 grid_x = value_grid[:, 0] grid_y = value_grid[:, 1] # Calculate the distance between each grid cell and band cell grid_vector = bm.zeros(value_grid.shape) grid_vector = grid_vector.at[:, 0].set(value_grid[:, 0]) grid_vector = grid_vector.at[:, 1].set( (value_grid[:, 1] - 1 / 2 * value_grid[:, 0]) * 2 / bm.sqrt(3) ) z_vector = bm.array([-1 / 2, bm.sqrt(3) / 2]) grid_phase_z = bm.dot(grid_vector, z_vector) dis_x = self.band_cell_x.dist(grid_x[:, None] - band_x[None, :]) dis_y = self.band_cell_y.dist(grid_y[:, None] - band_y[None, :]) dis_z = self.band_cell_z.dist(grid_phase_z[:, None] - band_z[None, :]) self.W_x_grid = ( J0 * bm.exp(-0.5 * bm.square(dis_x / self.band_cell_x.band_cells.a)) / (bm.sqrt(2 * bm.pi) * self.band_cell_x.band_cells.a) ) self.W_y_grid = ( J0 * bm.exp(-0.5 * bm.square(dis_y / self.band_cell_y.band_cells.a)) / (bm.sqrt(2 * bm.pi) * self.band_cell_y.band_cells.a) ) self.W_z_grid = ( J0 * bm.exp(-0.5 * bm.square(dis_z / self.band_cell_z.band_cells.a)) / (bm.sqrt(2 * bm.pi) * self.band_cell_z.band_cells.a) )
[docs] def Postophase(self, pos): """Projects a 2D position to the 2D phase space of the grid module. Args: pos (Array): The 2D position vector. Returns: Array: The corresponding 2D phase vector. """ phase_x = bm.mod(bm.dot(pos, self.proj_k_x), 2 * bm.pi) - bm.pi phase_y = bm.mod(bm.dot(pos, self.proj_k_y), 2 * bm.pi) - bm.pi return bm.array([phase_x, phase_y]).transpose()
[docs] def make_Wg2p(self): """Creates the connection weights from grid cells to place cells. The connection strength is determined by the proximity of a place cell's center to a grid cell's firing field, calculated in the phase domain. """ phase_place = self.Postophase(self.place_center) phase_grid = self.grid_cell.value_grid d = phase_place[:, jnp.newaxis, :] - phase_grid[jnp.newaxis, :, :] d = map2pi(d) delta_x = d[:, :, 0] delta_y = (d[:, :, 1] - 1 / 2 * d[:, :, 0]) * 2 / bm.sqrt(3) # delta_x = d[:,:,0] + d[:,:,1]/2 # delta_y = d[:,:,1] * bm.sqrt(3) / 2 dis = bm.sqrt(delta_x**2 + delta_y**2) Wg2p = bm.exp(-0.5 * bm.square(dis / self.band_cell_x.band_cells.a)) / ( bm.sqrt(2 * bm.pi) * self.band_cell_x.band_cells.a ) self.Wg2p = Wg2p
[docs] def dist(self, d): """Calculates the distance on the hexagonal grid. Args: d (Array): An array of difference vectors in the phase space. Returns: Array: The corresponding distances on the hexagonal lattice. """ d = map2pi(d) delta_x = d[:, 0] delta_y = (d[:, 1] - 1 / 2 * d[:, 0]) * 2 / bm.sqrt(3) return bm.sqrt(delta_x**2 + delta_y**2)
[docs] def get_input(self, Phase): """Generates a stimulus input for the grid cell based on a 2D phase. Args: Phase (Array): The 2D phase vector. Returns: Array: The stimulus input vector for the grid cells. """ dis = self.dist(Phase - self.grid_cell.value_grid) return bm.exp(-0.5 * bm.square(dis / self.band_cell_x.band_cells.a)) / ( bm.sqrt(2 * bm.pi) * self.band_cell_x.band_cells.a )
[docs] def update(self, velocity, loc, loc_input_stre=0.0): self.band_cell_x(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre) self.band_cell_y(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre) self.band_cell_z(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre) band_output = ( self.W_x_grid @ self.band_cell_x.band_cells.r.value + self.W_y_grid @ self.band_cell_y.band_cells.r.value + self.W_z_grid @ self.band_cell_z.band_cells.r.value ) # band_output = (self.W_x_grid @ self.band_cell_x.Band_cells.r + self.W_y_grid @ self.band_cell_y.Band_cells.r) max_output = bm.max(band_output) band_output = bm.where(band_output > max_output / 2, band_output - max_output / 2, 0) phase_x = self.band_cell_x.center.value phase_y = self.band_cell_y.center.value Phase = bm.array([phase_x, phase_y]).transpose() # Phase = self.Postophase(loc) loc_input = self.get_input(Phase) * 5000 self.grid_cell.update(input=loc_input) grid_fr = self.grid_cell.r.value # self.grid_output = bm.dot(self.Wg2p, grid_fr-bm.max(grid_fr)/2) self.grid_output.value = bm.dot(self.Wg2p, grid_fr)
# band_cell_x_states = self.band_cell_x.states() # band_cell_y_states = self.band_cell_y.states() # band_cell_z_states = self.band_cell_z.states() # gird_cell_states = self.grid_cell.states() # # return { # 'band_cell_x': band_cell_x_states, # 'band_cell_y': band_cell_y_states, # 'band_cell_z': band_cell_z_states, # 'grid_cell': gird_cell_states, # # 'gird_fr': gird_cell_states['r'], # 'band_x_fr': band_cell_x_states['band_cells']['r'], # 'band_y_fr': band_cell_y_states['band_cells']['r'], # 'grid_output': self.grid_output, # }
[docs] class HierarchicalNetwork(BasicModelGroup): """A full hierarchical network composed of multiple grid modules. This class creates and manages a collection of `HierarchicalPathIntegrationModel` modules, each with a different grid spacing. By combining the outputs of these modules, the network can represent position unambiguously over a large area. The final output is a population of place cells whose activities are used to decode the animal's estimated position. Attributes: num_module (int): The number of grid modules in the network. num_place (int): The number of place cells in the output layer. place_center (bm.math.ndarray): The center locations of the place cells. MEC_model_list (list): A list containing all the `HierarchicalPathIntegrationModel` instances. grid_fr (bm.Variable): The firing rates of the grid cell population. band_x_fr (bm.Variable): The firing rates of the x-oriented band cell population. band_y_fr (bm.Variable): The firing rates of the y-oriented band cell population. place_fr (bm.Variable): The firing rates of the place cell population. decoded_pos (bm.Variable): The final decoded 2D position. References: Anonymous Author(s) "Unfolding the Black Box of Recurrent Neural Networks for Path Integration" (under review). """ def __init__( self, num_module, num_place, # Module spacing configuration spacing_min=2.0, spacing_max=5.0, module_angle=0.0, # BandCell configuration band_size=180, band_noise=0.0, band_w_L2S=0.2, band_w_S2L=1.0, band_gain=0.2, # GridCell configuration grid_num=20, grid_tau=0.1, grid_tau_v=10.0, grid_k=5e-3, grid_a=bm.pi / 9, grid_A=1.0, grid_J0=1.0, grid_mbar=1.0, # GaussRecUnits configuration (for BandCell) gauss_tau=1.0, gauss_J0=1.1, gauss_k=5e-4, gauss_a=2 / 9 * bm.pi, # NonRecUnits configuration (for BandCell) nonrec_tau=0.1, ): """Initializes the HierarchicalNetwork. Args: num_module (int): The number of grid modules to create. num_place (int): The number of place cells along one dimension of a square grid. spacing_min (float, optional): Minimum spacing for grid modules. Defaults to 2.0. spacing_max (float, optional): Maximum spacing for grid modules. Defaults to 5.0. module_angle (float, optional): Base orientation angle for all modules. Defaults to 0.0. band_size (int, optional): Number of neurons in each BandCell group. Defaults to 180. band_noise (float, optional): Noise level for BandCells. Defaults to 0.0. band_w_L2S (float, optional): Weight from band cells to shifter units. Defaults to 0.2. band_w_S2L (float, optional): Weight from shifter units to band cells. Defaults to 1.0. band_gain (float, optional): Gain factor for velocity signal in BandCells. Defaults to 0.2. grid_num (int, optional): Number of neurons per dimension for GridCell. Defaults to 20. grid_tau (float, optional): Synaptic time constant for GridCell. Defaults to 0.1. grid_tau_v (float, optional): Adaptation time constant for GridCell. Defaults to 10.0. grid_k (float, optional): Global inhibition strength for GridCell. Defaults to 5e-3. grid_a (float, optional): Connection width for GridCell. Defaults to pi/9. grid_A (float, optional): External input magnitude for GridCell. Defaults to 1.0. grid_J0 (float, optional): Maximum connection strength for GridCell. Defaults to 1.0. grid_mbar (float, optional): Base adaptation strength for GridCell. Defaults to 1.0. gauss_tau (float, optional): Time constant for GaussRecUnits in BandCells. Defaults to 1.0. gauss_J0 (float, optional): Connection strength scaling for GaussRecUnits. Defaults to 1.1. gauss_k (float, optional): Global inhibition for GaussRecUnits. Defaults to 5e-4. gauss_a (float, optional): Connection width for GaussRecUnits. Defaults to 2/9*pi. nonrec_tau (float, optional): Time constant for NonRecUnits in BandCells. Defaults to 0.1. """ super().__init__()
[docs] self.num_module = num_module
[docs] self.num_place = num_place**2
# randomly sample num_place place field centers from a square arena (5m x 5m) x = bm.linspace(0, 5, num_place) X, Y = bm.meshgrid(x, x)
[docs] self.place_center = bm.stack([X.flatten(), Y.flatten()]).T
# self.place_center = 5 * bm.random.rand(num_place,2) # load heatmaps_grid from heatmaps_grid.npz # data = np.load('heatmaps_grid.npz', allow_pickle=True) # heatmaps_grid = data['heatmaps_grid'] # print(heatmaps_grid.shape) MEC_model_list = [] # self.W_g2p_list = [] spacing = bm.linspace(spacing_min, spacing_max, num_module) for i in range(num_module): MEC_model_list.append( HierarchicalPathIntegrationModel( spacing=spacing[i], angle=module_angle, place_center=self.place_center, band_size=band_size, band_noise=band_noise, band_w_L2S=band_w_L2S, band_w_S2L=band_w_S2L, band_gain=band_gain, grid_num=grid_num, grid_tau=grid_tau, grid_tau_v=grid_tau_v, grid_k=grid_k, grid_a=grid_a, grid_A=grid_A, grid_J0=grid_J0, grid_mbar=grid_mbar, gauss_tau=gauss_tau, gauss_J0=gauss_J0, gauss_k=gauss_k, gauss_a=gauss_a, nonrec_tau=nonrec_tau, ) ) # W_g2p = self.W_place2grid(heatmaps_grid[i*400:(i+1)*400]) # self.W_g2p_list.append(W_g2p)
[docs] self.MEC_model_list = MEC_model_list
[docs] self.place_fr = bm.Variable(bm.zeros(self.num_place))
[docs] self.grid_fr = bm.Variable(bm.zeros((self.num_module, 20**2)))
[docs] self.band_x_fr = bm.Variable(bm.zeros((self.num_module, 180)))
[docs] self.band_y_fr = bm.Variable(bm.zeros((self.num_module, 180)))
[docs] self.decoded_pos = bm.Variable(bm.zeros(2))
[docs] def update(self, velocity, loc, loc_input_stre=0.0): grid_output = bm.zeros(self.num_place) for i in range(self.num_module): # update the band cell module self.MEC_model_list[i](velocity=velocity, loc=loc, loc_input_stre=loc_input_stre) self.grid_fr.value = self.grid_fr.value.at[i].set( self.MEC_model_list[i].grid_cell.r.value ) self.band_x_fr.value = self.band_x_fr.value.at[i].set( self.MEC_model_list[i].band_cell_x.band_cells.r.value ) self.band_y_fr.value = self.band_y_fr.value.at[i].set( self.MEC_model_list[i].band_cell_y.band_cells.r.value ) grid_output_module = self.MEC_model_list[i].grid_output.value # W_g2p = self.W_g2p_list[i] # grid_fr = self.MEC_model_list[i].Grid_cell.r # grid_output_module = bm.dot(W_g2p, grid_fr) grid_output += grid_output_module # update the place cell module grid_output = bm.where(grid_output > 0, grid_output, 0) u_place = bm.where( grid_output > bm.max(grid_output) / 2, grid_output - bm.max(grid_output) / 2, 0 ) # grid_output = grid_output**2/(1+bm.sum(grid_output**2)) # max_id = bm.argmax(grid_output) # center = self.place_center[max_id] center = bm.sum(self.place_center * u_place[:, jnp.newaxis], axis=0) / ( 1e-5 + bm.sum(u_place) ) self.decoded_pos.value = center self.place_fr.value = u_place**2 / (1 + bm.sum(u_place**2))
# self.place_fr = softmax(grid_output) # the optimized run function is not run well(the performance is not good enough, as the original one), ''' def run(self, indices, velocities, positions, loc_input_stre=0.0, pbar=None): """Runs the hierarchical network for a series of time steps. Args: indices (Array): The indices of the time steps to run. velocities (Array): The 2D velocity vectors at each time step. positions (Array): The 2D position vectors at each time step. loc_input_stre (Array): The strength of the location-based input. p_bar (ProgressBar): A progress bar for tracking the simulation progress. """ band_x_r = bm.zeros((indices.shape[0], self.num_module, 180)) band_y_r = bm.zeros((indices.shape[0], self.num_module, 180)) grid_r = bm.zeros((indices.shape[0], self.num_module, 20**2)) grid_output = bm.zeros((indices.shape[0], self.num_place)) loc_input_stre = bm.ones((indices.shape[0],)) * loc_input_stre for i, model in enumerate(self.MEC_model_list): def run_single_module(velocity, loc, loc_input_stre): model(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre) return ( model.band_cell_x.band_cells.r.value, model.band_cell_y.band_cells.r.value, model.grid_cell.r.value, model.grid_output.value, ) single_band_x_r, single_band_y_r, single_grid_r, single_grid_output = ( bm.for_loop( run_single_module, velocities, positions, loc_input_stre, ) ) band_x_r = band_x_r.at[:, i, :].set(single_band_x_r) band_y_r = band_y_r.at[:, i, :].set(single_band_y_r) grid_r = grid_r.at[:, i, :].set(single_grid_r) grid_output += single_grid_output grid_output = bm.where(grid_output > 0, grid_output, 0) u_place = bm.where( grid_output > bm.max(grid_output, axis=1, keepdims=True) / 2, grid_output - bm.max(grid_output, axis=1, keepdims=True) / 2, 0, ) place_r = u_place**2 / (1 + bm.sum(u_place**2, axis=1, keepdims=True)) return band_x_r, band_y_r, grid_r, place_r '''