src.canns.analyzer.slow_points.finder
Fixed point finder for BrainPy RNN models.
Classes
- FixedPointFinder – Find and analyze fixed points in RNN dynamics.
Module Contents
- class src.canns.analyzer.slow_points.finder.FixedPointFinder(rnn_model, method='joint', max_iters=5000, tol_q=1e-12, tol_dq=1e-20, lr_init=1.0, lr_factor=0.95, lr_patience=5, lr_cooldown=0, do_compute_jacobians=True, do_decompose_jacobians=True, tol_unique=0.001, do_exclude_distance_outliers=True, outlier_distance_scale=10.0, do_rerun_q_outliers=False, outlier_q_scale=10.0, max_n_unique=np.inf, final_q_threshold=1e-08, dtype='float32', verbose=True, super_verbose=False, n_iters_per_print_update=100)
Find and analyze fixed points in RNN dynamics.
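This class implements an optimization-based approach to finding fixed points in recurrent neural networks. It uses gradient descent to minimize the objective q = 0.5 * ||x - F(x, u)||^2, where x is a candidate hidden state, u is a constant input, and F is the RNN transition function. Since q is zero exactly when x = F(x, u), its minima at q = 0 are the fixed points.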
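A minimal sketch of this objective, assuming a hypothetical vanilla tanh RNN F(x, u) = tanh(W x + B u + b) and a single initial state (the class itself works with BrainPy models and obtains gradients through JAX; here the gradient of q is derived by hand in NumPy). The names `rnn_step`, `q_and_grad` and `minimize_q` are illustrative, not part of the library.

```python
import numpy as np

def rnn_step(x, u, W, B, b):
    """Hypothetical transition F(x, u) = tanh(W x + B u + b)."""
    return np.tanh(W @ x + B @ u + b)

def q_and_grad(x, u, W, B, b):
    """Return q = 0.5 * ||x - F(x, u)||^2 and dq/dx."""
    f = rnn_step(x, u, W, B, b)
    r = x - f                          # residual x - F(x, u)
    q = 0.5 * float(r @ r)
    J = (1.0 - f**2)[:, None] * W      # dF/dx = diag(1 - f^2) W
    grad = r - J.T @ r                 # dq/dx = (I - J)^T r
    return q, grad

def minimize_q(x0, u, W, B, b, lr=0.5, max_iters=5000, tol_q=1e-12):
    """Plain gradient descent on q from a single initial state."""
    x = x0.copy()
    q, g = q_and_grad(x, u, W, B, b)
    for _ in range(max_iters):
        if q < tol_q:                  # converged to (numerically) q = 0
            break
        x = x - lr * g
        q, g = q_and_grad(x, u, W, B, b)
    return x, q

rng = np.random.default_rng(0)
n, m = 8, 2
W = rng.standard_normal((n, n))
W *= 0.5 / np.linalg.norm(W, 2)        # spectral norm 0.5: F is a contraction in x
B = rng.standard_normal((n, m))
b = np.zeros(n)
u = np.array([0.3, -0.1])
x_star, q_star = minimize_q(rng.standard_normal(n), u, W, B, b)
print(q_star)
```

The weights are scaled so that F is a contraction, which guarantees a single fixed point; with real trained networks there are usually many minima, which is why the class starts from many sampled states.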
The implementation is compatible with BrainPy RNN models and uses JAX for automatic differentiation and optimization.
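Once a fixed point is found, its Jacobian dF/dx describes the linearized dynamics around it (this is what `do_compute_jacobians` and `do_decompose_jacobians` refer to). The sketch below computes that Jacobian analytically for the same hypothetical tanh RNN instead of through JAX, and counts unstable modes under the assumption that F is a discrete-time update, so a mode grows when its eigenvalue has magnitude greater than 1. The helper names are illustrative only.

```python
import numpy as np

def tanh_rnn_jacobian(x, u, W, B, b):
    """dF/dx for the hypothetical transition F(x, u) = tanh(W x + B u + b)."""
    f = np.tanh(W @ x + B @ u + b)
    return (1.0 - f**2)[:, None] * W

def classify_fixed_point(J):
    """Eigendecompose J and count modes with |lambda| > 1 (discrete-time assumption)."""
    eigvals, eigvecs = np.linalg.eig(J)
    order = np.argsort(-np.abs(eigvals))        # largest magnitude first
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    n_unstable = int(np.sum(np.abs(eigvals) > 1.0))
    return eigvals, eigvecs, n_unstable

W = np.diag([0.5, 1.5])
B = np.zeros((2, 1))
b = np.zeros(2)
x_star = np.zeros(2)                            # tanh(W @ 0) = 0, so the origin is fixed
J = tanh_rnn_jacobian(x_star, np.zeros(1), W, B, b)
eigvals, _, n_unstable = classify_fixed_point(J)
print(np.abs(eigvals), n_unstable)              # [1.5 0.5] 1
```

One decaying and one growing direction make this origin a saddle point.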
Initialize the FixedPointFinder.
- Parameters:
rnn_model (brainpy.DynamicalSystem) – A BrainPy RNN model with a __call__(inputs, hidden) signature.
method (str) – Optimization method (‘joint’ or ‘sequential’).
max_iters (int) – Maximum optimization iterations.
tol_q (float) – Tolerance for q value convergence.
tol_dq (float) – Tolerance for change in q value.
lr_init (float) – Initial learning rate.
lr_factor (float) – Learning rate reduction factor.
lr_patience (int) – Patience for learning rate scheduler.
lr_cooldown (int) – Cooldown for learning rate scheduler.
do_compute_jacobians (bool) – Whether to compute Jacobians.
do_decompose_jacobians (bool) – Whether to eigendecompose Jacobians.
tol_unique (float) – Tolerance for identifying unique fixed points.
do_exclude_distance_outliers (bool) – Whether to exclude distance outliers.
outlier_distance_scale (float) – Scale for distance outlier detection.
do_rerun_q_outliers (bool) – Whether to rerun optimization on q outliers.
outlier_q_scale (float) – Scale for q outlier detection.
max_n_unique (float) – Maximum number of unique fixed points to keep.
dtype (str) – Data type for computations.
verbose (bool) – Print high-level status updates.
super_verbose (bool) – Print per-iteration updates.
n_iters_per_print_update (int) – Print frequency during optimization.
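Two of the filtering options above, `tol_unique` and `do_exclude_distance_outliers` with `outlier_distance_scale`, can be illustrated with a short NumPy sketch. The exact criteria the class applies are not documented here, so both rules below are assumptions: a greedy Euclidean deduplication, and a distance test relative to the spread of the initial states around their centroid. The function names are hypothetical.

```python
import numpy as np

def unique_fixed_points(xstar, tol_unique=1e-3):
    """Keep a point only if it is farther than tol_unique (Euclidean)
    from every point already kept. Assumed criterion, for illustration."""
    kept = []
    for x in xstar:
        if all(np.linalg.norm(x - k) > tol_unique for k in kept):
            kept.append(x)
    return np.array(kept)

def exclude_distance_outliers(xstar, init_states, scale=10.0):
    """Drop fixed points farther from the centroid of the initial states
    than scale times the mean initial-state distance to that centroid."""
    centroid = init_states.mean(axis=0)
    ref = np.linalg.norm(init_states - centroid, axis=1).mean()
    dist = np.linalg.norm(xstar - centroid, axis=1)
    return xstar[dist <= scale * ref]

xstar = np.array([[0.0, 0.0], [0.0, 0.0005], [1.0, 0.0], [50.0, 50.0]])
inits = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
uniq = unique_fixed_points(xstar, tol_unique=1e-3)   # merges the first two points
kept = exclude_distance_outliers(uniq, inits, scale=10.0)
print(kept)
```

The point at (50, 50) lies about 70 units from the centroid while the initial states lie 1 unit away on average, so it is discarded as an outlier.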
- find_fixed_points(state_traj, inputs, n_inits=1024, noise_scale=0.0, valid_bxt=None, cond_ids=None)
Find fixed points from sampled RNN states.
- Parameters:
state_traj (numpy.ndarray) – [n_batch x n_time x n_states] trajectory of RNN states.
inputs (numpy.ndarray) – [1 x n_inputs] or [n_inits x n_inputs] constant inputs.
n_inits (int) – Number of initial states to sample.
noise_scale (float) – Std dev of Gaussian noise added to sampled states.
valid_bxt (numpy.ndarray | None) – [n_batch x n_time] boolean mask for valid samples.
cond_ids (numpy.ndarray | None) – [n_inits,] condition IDs for each initialization.
- Returns:
unique_fps – FixedPoints object with the unique fixed points.
all_fps – FixedPoints object with all fixed points before filtering.
- Return type:
tuple[FixedPoints, FixedPoints] (unique_fps, all_fps)
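The first step of `find_fixed_points` is choosing `n_inits` starting states from `state_traj`, restricted by `valid_bxt` and jittered by `noise_scale`. A minimal NumPy sketch of that sampling step, assuming uniform sampling with replacement over the valid (batch, time) entries; the helper `sample_initial_states` is hypothetical, ignores `cond_ids`, and may differ from what the class does internally.

```python
import numpy as np

def sample_initial_states(state_traj, n_inits, noise_scale=0.0, valid_bxt=None, rng=None):
    """Sample n_inits states from a [n_batch x n_time x n_states] trajectory."""
    rng = np.random.default_rng() if rng is None else rng
    n_batch, n_time, n_states = state_traj.shape
    if valid_bxt is None:
        valid_bxt = np.ones((n_batch, n_time), dtype=bool)
    b_idx, t_idx = np.nonzero(valid_bxt)              # coordinates of valid samples
    pick = rng.integers(0, b_idx.size, size=n_inits)  # uniform, with replacement
    states = state_traj[b_idx[pick], t_idx[pick]]     # -> [n_inits x n_states]
    if noise_scale > 0:
        states = states + noise_scale * rng.standard_normal(states.shape)
    return states

rng = np.random.default_rng(0)
traj = rng.standard_normal((4, 50, 16))   # placeholder trajectory, not real RNN data
mask = np.zeros((4, 50), dtype=bool)
mask[:, 10:] = True                       # e.g. ignore the first 10 time steps
inits = sample_initial_states(traj, n_inits=128, noise_scale=0.01, valid_bxt=mask, rng=rng)
print(inits.shape)                        # (128, 16)
```

Adding a small amount of noise helps the optimizer explore the neighborhood of visited states rather than restarting from exactly the same points.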