src.canns.analyzer.slow_points.finder

Fixed point finder for BrainPy RNN models.

Classes

FixedPointFinder

Find and analyze fixed points in RNN dynamics.

Module Contents

class src.canns.analyzer.slow_points.finder.FixedPointFinder(rnn_model, method='joint', max_iters=5000, tol_q=1e-12, tol_dq=1e-20, lr_init=1.0, lr_factor=0.95, lr_patience=5, lr_cooldown=0, do_compute_jacobians=True, do_decompose_jacobians=True, tol_unique=0.001, do_exclude_distance_outliers=True, outlier_distance_scale=10.0, do_rerun_q_outliers=False, outlier_q_scale=10.0, max_n_unique=np.inf, final_q_threshold=1e-08, dtype='float32', verbose=True, super_verbose=False, n_iters_per_print_update=100)

Find and analyze fixed points in RNN dynamics.

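This class implements an optimization-based approach to finding fixed points in recurrent neural networks. It uses gradient descent to minimize the objective q = 0.5 * ||x - F(x, u)||^2, where F is the RNN transition function, x is the hidden state, and u is a constant input. A state with q = 0 is an exact fixed point; a state with small but nonzero q is a slow point.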

The implementation is compatible with BrainPy RNN models and uses JAX for automatic differentiation and optimization.
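The objective above can be illustrated with a minimal NumPy sketch. It is not the library's implementation, which uses JAX and a BrainPy model. The toy transition function F(x, u) = tanh(W x + U u + b), the matrices W, U, b, and the helper q_and_grad are assumptions made for illustration, and the gradient is written out by hand instead of being obtained by automatic differentiation.

```python
import numpy as np

# Hypothetical toy transition function F(x, u) = tanh(W x + U u + b).
# W, U, and b are illustrative values, not part of the library.
W = np.array([[0.3, -0.2], [0.1, 0.4]])
U = np.array([[0.5], [-0.3]])
b = np.array([0.1, -0.2])

def F(x, u):
    return np.tanh(W @ x + U @ u + b)

def q_and_grad(x, u):
    """Return q = 0.5 * ||x - F(x, u)||^2 and its gradient with respect to x."""
    fx = F(x, u)
    r = x - fx                          # residual x - F(x, u)
    J = (1.0 - fx**2)[:, None] * W      # Jacobian dF/dx for the tanh model
    q = 0.5 * r @ r
    grad = (np.eye(len(x)) - J).T @ r   # chain rule: (I - J)^T r
    return q, grad

u = np.array([0.2])   # constant input
x = np.zeros(2)       # initial state
lr = 0.5              # fixed step size, chosen for this toy problem
for _ in range(500):
    q, g = q_and_grad(x, u)
    x = x - lr * g

q_final, _ = q_and_grad(x, u)
```

In this toy problem q_final falls well below the default tol_q, and x satisfies x ≈ F(x, u).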

Initialize the FixedPointFinder.

Parameters:
  • rnn_model (brainpy.DynamicalSystem) – A BrainPy RNN model with __call__(inputs, hidden) signature.

  • method (str) – Optimization method (‘joint’ or ‘sequential’).

  • max_iters (int) – Maximum optimization iterations.

  • tol_q (float) – Tolerance for q value convergence.

  • tol_dq (float) – Tolerance for change in q value.

  • lr_init (float) – Initial learning rate.

  • lr_factor (float) – Learning rate reduction factor.

  • lr_patience (int) – Patience for learning rate scheduler.

  • lr_cooldown (int) – Cooldown for learning rate scheduler.

  • do_compute_jacobians (bool) – Whether to compute Jacobians.

  • do_decompose_jacobians (bool) – Whether to eigendecompose Jacobians.

  • tol_unique (float) – Tolerance for identifying unique fixed points.

  • do_exclude_distance_outliers (bool) – Whether to exclude distance outliers.

  • outlier_distance_scale (float) – Scale for distance outlier detection.

  • do_rerun_q_outliers (bool) – Whether to rerun optimization on q outliers.

  • outlier_q_scale (float) – Scale for q outlier detection.

  • max_n_unique (float) – Maximum number of unique fixed points to keep.

  • dtype (str) – Data type for computations.

  • verbose (bool) – Print high-level status updates.

  • super_verbose (bool) – Print per-iteration updates.

  • n_iters_per_print_update (int) – Print frequency during optimization.
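The parameters lr_init, lr_factor, lr_patience, and lr_cooldown suggest a reduce-on-plateau learning rate schedule. The sketch below shows one common form of such a schedule. It is an assumption about the general technique, not the scheduler used by FixedPointFinder, and the class name PlateauScheduler is hypothetical.

```python
class PlateauScheduler:
    """Hypothetical reduce-on-plateau schedule (not the library's class)."""

    def __init__(self, lr_init=1.0, lr_factor=0.95, lr_patience=5, lr_cooldown=0):
        self.lr = lr_init
        self.factor = lr_factor
        self.patience = lr_patience
        self.cooldown = lr_cooldown
        self.best = float("inf")
        self.bad_steps = 0
        self.cooldown_left = 0

    def step(self, q):
        """Update the schedule with the latest objective value and return the learning rate."""
        if q < self.best:
            # Improvement: remember it and reset the patience counter.
            self.best = q
            self.bad_steps = 0
        elif self.cooldown_left > 0:
            # Still cooling down after a reduction: do not count this step.
            self.cooldown_left -= 1
        else:
            self.bad_steps += 1
            if self.bad_steps > self.patience:
                # Patience exhausted: shrink the learning rate.
                self.lr *= self.factor
                self.bad_steps = 0
                self.cooldown_left = self.cooldown
        return self.lr

sched = PlateauScheduler()
# One improving step followed by six steps without improvement.
lrs = [sched.step(q) for q in [1.0] + [1.0] * 6]
```

With the default values, the learning rate stays at lr_init for the first six steps and is multiplied by lr_factor on the seventh, once more than lr_patience consecutive steps have failed to improve q.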

find_fixed_points(state_traj, inputs, n_inits=1024, noise_scale=0.0, valid_bxt=None, cond_ids=None)

Find fixed points from sampled RNN states.

Parameters:
  • state_traj (numpy.ndarray) – [n_batch x n_time x n_states] trajectory of RNN states.

  • inputs (numpy.ndarray) – [1 x n_inputs] or [n_inits x n_inputs] constant inputs.

  • n_inits (int) – Number of initial states to sample.

  • noise_scale (float) – Std dev of Gaussian noise added to sampled states.

  • valid_bxt (numpy.ndarray | None) – [n_batch x n_time] boolean mask for valid samples.

  • cond_ids (numpy.ndarray | None) – [n_inits,] condition IDs for each initialization.
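The way state_traj, n_inits, noise_scale, and valid_bxt interact can be sketched as follows. The helper sample_initial_states is hypothetical and only illustrates the idea of drawing initial states from valid trajectory samples and optionally perturbing them; the library's own sampling may differ in detail (for example, in whether it samples with replacement).

```python
import numpy as np

def sample_initial_states(state_traj, n_inits, noise_scale=0.0, valid_bxt=None, rng=None):
    """Hypothetical helper: draw n_inits states from a [n_batch, n_time, n_states] array."""
    rng = np.random.default_rng() if rng is None else rng
    n_batch, n_time, n_states = state_traj.shape
    if valid_bxt is None:
        # Without a mask, every (batch, time) sample is a candidate.
        valid_bxt = np.ones((n_batch, n_time), dtype=bool)
    candidates = state_traj[valid_bxt]                          # [n_valid, n_states]
    idx = rng.integers(0, candidates.shape[0], size=n_inits)    # sample with replacement
    states = candidates[idx].copy()
    if noise_scale > 0.0:
        # Isotropic Gaussian noise with standard deviation noise_scale.
        states += noise_scale * rng.standard_normal(states.shape)
    return states

rng = np.random.default_rng(0)
traj = rng.standard_normal((4, 10, 3))
mask = np.zeros((4, 10), dtype=bool)
mask[:, 5:] = True   # only the last five time steps of each trial are valid
inits = sample_initial_states(traj, n_inits=16, valid_bxt=mask, rng=rng)
```

Here inits has shape [n_inits x n_states], and with noise_scale=0.0 every row is an exact copy of a state from the masked part of the trajectory.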

Returns:
  • unique_fps – FixedPoints object with unique fixed points.

  • all_fps – FixedPoints object with all fixed points before filtering.

Return type:

tuple (unique_fps, all_fps)
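The distinction between all fixed points and the unique ones depends on tol_unique. One plausible reading, shown below as a sketch, is a greedy filter on Euclidean distance. The function unique_states and the distance criterion are assumptions for illustration; the source does not specify how the library compares candidate fixed points.

```python
import numpy as np

def unique_states(xstar, tol_unique=1e-3):
    """Hypothetical greedy filter: keep a point only if it lies farther than
    tol_unique (Euclidean distance) from every point already kept."""
    kept = []
    for i, x in enumerate(xstar):
        if all(np.linalg.norm(x - xstar[j]) > tol_unique for j in kept):
            kept.append(i)
    return np.array(kept, dtype=int)

# Four candidate fixed points forming two near-duplicate pairs.
all_fps = np.array([[0.0, 0.0], [0.0, 1e-5], [1.0, 1.0], [1.0, 1.0 + 1e-5]])
idx = unique_states(all_fps)
unique_fps = all_fps[idx]
```

The two near-duplicate pairs collapse to their first members, so two of the four candidates remain.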

do_compute_jacobians = True
do_decompose_jacobians = True
do_exclude_distance_outliers = True
do_rerun_q_outliers = False
final_q_threshold
lr_cooldown = 0
lr_factor
lr_init
lr_patience = 5
max_iters = 5000
max_n_unique
method = 'joint'
n_iters_per_print_update = 100
outlier_distance_scale
outlier_q_scale
rng
rnn_model
super_verbose = False
tol_dq
tol_q
tol_unique
verbose = True
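The do_compute_jacobians and do_decompose_jacobians options refer to linearizing the dynamics at each fixed point and eigendecomposing the result. A minimal NumPy sketch of that analysis follows, again for a hypothetical tanh model with illustrative W and b. The library obtains Jacobians from the BrainPy model with JAX, so the hand-written Jacobian here is only a stand-in.

```python
import numpy as np

# Hypothetical toy model F(x) = tanh(W x + b) with the input folded into b.
W = np.array([[0.3, -0.2], [0.1, 0.4]])
b = np.array([0.1, -0.2])

def F(x):
    return np.tanh(W @ x + b)

# Locate the fixed point by plain iteration; this converges here because
# the toy map is a contraction (the spectral norm of W is below 1).
x = np.zeros(2)
for _ in range(200):
    x = F(x)

# Jacobian dF/dx at the fixed point, written by hand for the tanh model.
fx = F(x)
J = (1.0 - fx**2)[:, None] * W

# Eigendecomposition: for a discrete-time map, the fixed point is locally
# stable when every eigenvalue has magnitude below 1.
eigvals, eigvecs = np.linalg.eig(J)
is_stable = bool(np.all(np.abs(eigvals) < 1.0))
```

Eigenvalues with magnitude above 1 would indicate unstable directions, and the corresponding eigenvectors give those directions in state space.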