src.canns.analyzer.slow_points
Fixed point finder for BrainPy RNN models.
This module provides tools for identifying and analyzing fixed points in recurrent neural networks using JAX/BrainPy.
Classes
FixedPointFinder | Find and analyze fixed points in RNN dynamics.
FixedPoints | Container for storing and manipulating fixed points.
Functions
load_checkpoint | Load model parameters from a checkpoint file using BrainPy checkpointing.
plot_fixed_points_2d | Plot fixed points and trajectories in 2D using PCA.
plot_fixed_points_3d | Plot fixed points and trajectories in 3D using PCA.
save_checkpoint | Save model parameters to a checkpoint file using BrainPy checkpointing.
Package Contents
- class src.canns.analyzer.slow_points.FixedPointFinder(rnn_model, method='joint', max_iters=5000, tol_q=1e-12, tol_dq=1e-20, lr_init=1.0, lr_factor=0.95, lr_patience=5, lr_cooldown=0, do_compute_jacobians=True, do_decompose_jacobians=True, tol_unique=0.001, do_exclude_distance_outliers=True, outlier_distance_scale=10.0, do_rerun_q_outliers=False, outlier_q_scale=10.0, max_n_unique=np.inf, final_q_threshold=1e-08, dtype='float32', verbose=True, super_verbose=False, n_iters_per_print_update=100)
Find and analyze fixed points in RNN dynamics.
This class implements an optimization-based approach to finding fixed points in recurrent neural networks. It uses gradient descent to minimize the objective q = 0.5 * ||x - F(x, u)||^2, where F is the RNN transition function.
The implementation is compatible with BrainPy RNN models and uses JAX for automatic differentiation and optimization.
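The objective above can be illustrated with a small, self-contained JAX sketch. It is not the FixedPointFinder implementation: the transition function `F`, the weights `W`, `B`, `b`, the fixed learning rate, and the helper name `minimize_q` are all assumptions made for illustration, and plain gradient descent stands in for whatever optimizer and learning rate schedule the class actually uses.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy RNN (not a BrainPy model): F(x, u) = tanh(x W^T + u B^T + b).
n_states, n_inputs = 8, 2
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
W = 0.2 * jax.random.normal(k_w, (n_states, n_states)) / jnp.sqrt(n_states)
B = 0.1 * jnp.ones((n_states, n_inputs))
b = jnp.zeros(n_states)


def F(x, u):
    """One RNN step for a batch of states x [n x n_states] and inputs u."""
    return jnp.tanh(x @ W.T + u @ B.T + b)


def q_fn(x, u):
    """Per-state objective q = 0.5 * ||x - F(x, u)||^2, shape [n,]."""
    r = x - F(x, u)
    return 0.5 * jnp.sum(r ** 2, axis=-1)


def minimize_q(x0, u, lr=0.5, max_iters=500):
    """Plain gradient descent on the summed objective over all initial states."""
    grad_fn = jax.jit(jax.grad(lambda x: jnp.sum(q_fn(x, u))))
    x = x0
    for _ in range(max_iters):
        x = x - lr * grad_fn(x)
    return x, q_fn(x, u)


u = jnp.zeros((1, n_inputs))                   # constant input, broadcast over inits
x0 = jax.random.normal(k_x, (16, n_states))    # 16 initial states
q0 = q_fn(x0, u)
xstar, qstar = minimize_q(x0, u)
```

Summing q over the initial states lets one gradient call update every candidate at once, since the gradient for each state depends only on that state's own term.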
Initialize the FixedPointFinder.
- Parameters:
rnn_model (brainpy.DynamicalSystem) – A BrainPy RNN model with __call__(inputs, hidden) signature.
method (str) – Optimization method (‘joint’ or ‘sequential’).
max_iters (int) – Maximum optimization iterations.
tol_q (float) – Tolerance for q value convergence.
tol_dq (float) – Tolerance for change in q value.
lr_init (float) – Initial learning rate.
lr_factor (float) – Learning rate reduction factor.
lr_patience (int) – Patience for learning rate scheduler.
lr_cooldown (int) – Cooldown for learning rate scheduler.
do_compute_jacobians (bool) – Whether to compute Jacobians.
do_decompose_jacobians (bool) – Whether to eigendecompose Jacobians.
tol_unique (float) – Tolerance for identifying unique fixed points.
do_exclude_distance_outliers (bool) – Whether to exclude distance outliers.
outlier_distance_scale (float) – Scale for distance outlier detection.
do_rerun_q_outliers (bool) – Whether to rerun optimization on q outliers.
outlier_q_scale (float) – Scale for q outlier detection.
max_n_unique (float) – Maximum number of unique fixed points to keep.
dtype (str) – Data type for computations.
verbose (bool) – Print high-level status updates.
super_verbose (bool) – Print per-iteration updates.
n_iters_per_print_update (int) – Print frequency during optimization.
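The four `lr_*` parameters suggest a reduce-on-plateau style schedule. The sketch below shows one common way such a schedule behaves; the class name `PlateauLR` and its exact rules (when patience is counted, how cooldown is applied) are assumptions, not the library's scheduler.

```python
class PlateauLR:
    """Hypothetical reduce-on-plateau schedule driven by the objective q."""

    def __init__(self, lr_init=1.0, lr_factor=0.95, lr_patience=5, lr_cooldown=0):
        self.lr = lr_init
        self.factor = lr_factor
        self.patience = lr_patience
        self.cooldown = lr_cooldown
        self.best_q = float("inf")
        self.bad_steps = 0       # consecutive steps without improvement
        self.cooldown_left = 0   # steps to wait after a reduction

    def step(self, q):
        """Record the latest q and return the learning rate to use next."""
        if q < self.best_q:
            self.best_q = q
            self.bad_steps = 0
        elif self.cooldown_left > 0:
            self.cooldown_left -= 1
        else:
            self.bad_steps += 1
            if self.bad_steps > self.patience:
                self.lr *= self.factor
                self.bad_steps = 0
                self.cooldown_left = self.cooldown
        return self.lr


# Usage: q stalls at 1.0, so the rate drops once patience is exceeded.
sched = PlateauLR(lr_init=1.0, lr_factor=0.5, lr_patience=2)
lrs = [sched.step(1.0) for _ in range(4)]
```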
- find_fixed_points(state_traj, inputs, n_inits=1024, noise_scale=0.0, valid_bxt=None, cond_ids=None)
Find fixed points from sampled RNN states.
- Parameters:
state_traj (numpy.ndarray) – [n_batch x n_time x n_states] trajectory of RNN states.
inputs (numpy.ndarray) – [1 x n_inputs] or [n_inits x n_inputs] constant inputs.
n_inits (int) – Number of initial states to sample.
noise_scale (float) – Std dev of Gaussian noise added to sampled states.
valid_bxt (numpy.ndarray | None) – [n_batch x n_time] boolean mask for valid samples.
cond_ids (numpy.ndarray | None) – [n_inits,] condition IDs for each initialization.
- Returns:
A tuple (unique_fps, all_fps). unique_fps is a FixedPoints object with the unique fixed points; all_fps is a FixedPoints object with all fixed points before filtering.
- Return type:
tuple[FixedPoints, FixedPoints]
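The roles of `state_traj`, `n_inits`, `noise_scale`, and `valid_bxt` can be illustrated with a NumPy sketch of how initial states might be drawn. The helper `sample_initial_states` is hypothetical, and sampling with replacement is an assumption; the method's real sampling strategy may differ.

```python
import numpy as np


def sample_initial_states(state_traj, n_inits, noise_scale=0.0,
                          valid_bxt=None, rng=None):
    """Draw n_inits states from a [n_batch x n_time x n_states] trajectory."""
    rng = np.random.default_rng(0) if rng is None else rng
    n_batch, n_time, n_states = state_traj.shape
    if valid_bxt is None:
        valid_bxt = np.ones((n_batch, n_time), dtype=bool)

    # Flatten (batch, time) so each row is one candidate state.
    flat_states = state_traj.reshape(-1, n_states)
    valid_idx = np.flatnonzero(valid_bxt.reshape(-1))

    # Assumption: sample with replacement from the valid entries only.
    chosen = rng.choice(valid_idx, size=n_inits, replace=True)
    x0 = flat_states[chosen]

    # Optional Gaussian perturbation with standard deviation noise_scale.
    if noise_scale > 0:
        x0 = x0 + noise_scale * rng.standard_normal(x0.shape)
    return x0


# Usage: only the first batch element is marked valid.
traj = np.arange(4 * 5 * 3, dtype=float).reshape(4, 5, 3)
mask = np.zeros((4, 5), dtype=bool)
mask[0] = True
x0 = sample_initial_states(traj, n_inits=16, valid_bxt=mask)
```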
- do_compute_jacobians = True
- do_decompose_jacobians = True
- do_exclude_distance_outliers = True
- do_rerun_q_outliers = False
- final_q_threshold
- lr_cooldown = 0
- lr_factor
- lr_init
- lr_patience = 5
- max_iters = 5000
- max_n_unique
- method = 'joint'
- n_iters_per_print_update = 100
- outlier_distance_scale
- outlier_q_scale
- rng
- rnn_model
- super_verbose = False
- tol_dq
- tol_q
- tol_unique
- verbose = True
- class src.canns.analyzer.slow_points.FixedPoints(xstar=None, F_xstar=None, x_init=None, inputs=None, qstar=None, dq=None, n_iters=None, J_xstar=None, dFdu=None, eigval_J_xstar=None, eigvec_J_xstar=None, is_stable=None, cond_id=None, tol_unique=0.001, dtype=np.float32)
Container for storing and manipulating fixed points.
This class stores fixed points found by the FixedPointFinder algorithm, along with associated metadata like Jacobians, eigenvalues, and stability.
- xstar
[n x n_states] array of fixed point states.
- F_xstar
[n x n_states] array of states after one RNN step from xstar.
- x_init
[n x n_states] array of initial states used for optimization.
- inputs
[n x n_inputs] array of constant inputs during optimization.
- qstar
[n,] array of final q values (optimization objective).
- dq
[n,] array of change in q at the last optimization step.
- n_iters
[n,] array of iteration counts for each optimization.
- J_xstar
[n x n_states x n_states] array of Jacobians dF/dx at fixed points.
- dFdu
[n x n_states x n_inputs] array of Jacobians dF/du at fixed points.
- eigval_J_xstar
[n x n_states] complex array of eigenvalues.
- eigvec_J_xstar
[n x n_states x n_states] complex array of eigenvectors.
- is_stable
[n,] bool array indicating stability (max |eigenvalue| < 1).
- cond_id
[n,] array of condition IDs (optional).
- tol_unique
Tolerance for identifying unique fixed points.
- dtype
NumPy dtype for data storage.
Initialize a FixedPoints object.
- Parameters:
xstar (numpy.ndarray | None) – Fixed point states [n x n_states].
F_xstar (numpy.ndarray | None) – States after one RNN step [n x n_states].
x_init (numpy.ndarray | None) – Initial states [n x n_states].
inputs (numpy.ndarray | None) – Constant inputs [n x n_inputs].
qstar (numpy.ndarray | None) – Final q values [n,].
dq (numpy.ndarray | None) – Change in q at last step [n,].
n_iters (numpy.ndarray | None) – Iteration counts [n,].
J_xstar (numpy.ndarray | None) – Jacobians dF/dx [n x n_states x n_states].
dFdu (numpy.ndarray | None) – Jacobians dF/du [n x n_states x n_inputs].
eigval_J_xstar (numpy.ndarray | None) – Eigenvalues [n x n_states] (complex).
eigvec_J_xstar (numpy.ndarray | None) – Eigenvectors [n x n_states x n_states] (complex).
is_stable (numpy.ndarray | None) – Stability flags [n,].
cond_id (numpy.ndarray | None) – Condition IDs [n,].
tol_unique (float) – Tolerance for uniqueness detection.
dtype – NumPy data type for storage.
- __getitem__(idx)
Index into the fixed points.
- Parameters:
idx – Integer index, slice, or array of indices.
- Returns:
A new FixedPoints object containing the indexed subset.
- decompose_jacobians(verbose=False)
Compute eigendecomposition of Jacobians and determine stability.
Computes eigenvalues and eigenvectors for self.J_xstar and determines stability based on whether max |eigenvalue| < 1.
- Parameters:
verbose (bool) – Whether to print status messages.
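The stability rule (max |eigenvalue| < 1, the usual criterion for a discrete-time map) can be sketched with NumPy's batched eigendecomposition. The function name `decompose` and the cast to complex64 are illustrative assumptions; this is not the method's source code.

```python
import numpy as np


def decompose(J_xstar):
    """Eigendecompose a stack of Jacobians [n x n_states x n_states]."""
    # np.linalg.eig operates on the last two axes, so the stack is handled at once.
    eigvals, eigvecs = np.linalg.eig(J_xstar)
    eigvals = eigvals.astype(np.complex64)
    eigvecs = eigvecs.astype(np.complex64)

    # A fixed point of x_{t+1} = F(x_t) is stable when every eigenvalue
    # of dF/dx lies strictly inside the unit circle.
    is_stable = np.abs(eigvals).max(axis=1) < 1.0
    return eigvals, eigvecs, is_stable


# Usage: one contracting Jacobian and one with an expanding direction.
J = np.stack([0.5 * np.eye(2), np.diag([1.5, 0.2])])
eigvals, eigvecs, is_stable = decompose(J)
```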
- get_unique()
Identify and return unique fixed points.
Uniqueness is determined by Euclidean distance in the concatenated (xstar, inputs) space. Among duplicates, keeps the one with lowest qstar.
- Returns:
A new FixedPoints object containing only unique fixed points.
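One way to realize the uniqueness rule described above is a greedy pass over candidates sorted by qstar. The helper `unique_indices` below is a hypothetical sketch under that assumption; the library may cluster duplicates differently.

```python
import numpy as np


def unique_indices(xstar, inputs, qstar, tol_unique=1e-3):
    """Return indices of unique fixed points, preferring the lowest qstar."""
    # Distances are measured in the concatenated (xstar, inputs) space.
    z = np.concatenate([xstar, inputs], axis=1)

    kept = []
    # Visit candidates from best (lowest q) to worst, so that the survivor
    # of any group of near-duplicates is the one with the lowest qstar.
    for i in np.argsort(qstar):
        if all(np.linalg.norm(z[i] - z[j]) > tol_unique for j in kept):
            kept.append(int(i))
    return sorted(kept)


# Usage: points 0 and 1 are within tol_unique of each other; 1 has lower q.
xstar = np.array([[0.0, 0.0], [0.0, 1e-5], [1.0, 1.0]])
inputs = np.zeros((3, 1))
qstar = np.array([1e-6, 1e-9, 1e-8])
idx = unique_indices(xstar, inputs, qstar)
```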
- F_xstar = None
- J_xstar = None
- cond_id = None
- dFdu = None
- dq = None
- dtype
- eigval_J_xstar = None
- eigvec_J_xstar = None
- inputs = None
- is_stable = None
- n_iters = None
- qstar = None
- tol_unique
- x_init = None
- xstar = None
- src.canns.analyzer.slow_points.load_checkpoint(model, filepath)
Load model parameters from a checkpoint file using BrainPy checkpointing.
- Parameters:
model (brainpy.DynamicalSystem) – BrainPy model to load parameters into.
filepath (str) – Path to the checkpoint file.
- Returns:
True if checkpoint was loaded successfully, False otherwise.
- Return type:
bool
Example
>>> from canns.analyzer.slow_points import load_checkpoint
>>> if load_checkpoint(rnn, "my_model.msgpack"):
...     print("Loaded successfully")
... else:
...     print("No checkpoint found")
Loaded checkpoint from: my_model.msgpack
Loaded successfully
- src.canns.analyzer.slow_points.plot_fixed_points_2d(fixed_points, state_traj, config=None, plot_batch_idx=None, plot_start_time=0)
Plot fixed points and trajectories in 2D using PCA.
- Parameters:
fixed_points (src.canns.analyzer.slow_points.fixed_points.FixedPoints) – FixedPoints object containing analysis results.
state_traj (numpy.ndarray) – State trajectories [n_batch x n_time x n_states].
config (src.canns.analyzer.plotting.config.PlotConfig | None) – Plot configuration. If None, uses default config.
plot_batch_idx (list[int] | None) – Batch indices to plot trajectories. If None, plots first 30.
plot_start_time (int) – Starting time index for trajectory plotting.
- Returns:
matplotlib Figure object.
- Return type:
matplotlib.figure.Figure
Example
>>> from canns.analyzer.slow_points import plot_fixed_points_2d, FixedPoints
>>> from canns.analyzer.plotting import PlotConfig
>>> config = PlotConfig(
...     title="Fixed Points Analysis",
...     figsize=(10, 8),
...     save_path="fps_2d.png"
... )
>>> fig = plot_fixed_points_2d(unique_fps, hiddens, config=config)
- src.canns.analyzer.slow_points.plot_fixed_points_3d(fixed_points, state_traj, config=None, plot_batch_idx=None, plot_start_time=0)
Plot fixed points and trajectories in 3D using PCA.
- Parameters:
fixed_points (src.canns.analyzer.slow_points.fixed_points.FixedPoints) – FixedPoints object containing analysis results.
state_traj (numpy.ndarray) – State trajectories [n_batch x n_time x n_states].
config (src.canns.analyzer.plotting.config.PlotConfig | None) – Plot configuration. If None, uses default config.
plot_batch_idx (list[int] | None) – Batch indices to plot trajectories. If None, plots first 30.
plot_start_time (int) – Starting time index for trajectory plotting.
- Returns:
matplotlib Figure object.
- Return type:
matplotlib.figure.Figure
Example
>>> from canns.analyzer.slow_points import plot_fixed_points_3d, FixedPoints
>>> from canns.analyzer.plotting import PlotConfig
>>> config = PlotConfig(
...     title="Fixed Points 3D",
...     figsize=(12, 10),
...     save_path="fps_3d.png"
... )
>>> fig = plot_fixed_points_3d(unique_fps, hiddens, config=config)
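Both plotting functions project states onto principal components before drawing. The sketch below shows one plausible version of that projection step using an SVD in NumPy. The helper `pca_project` is hypothetical, and fitting the PCA basis on the trajectories (rather than on the fixed points or both) is an assumption.

```python
import numpy as np


def pca_project(state_traj, xstar, n_components=2):
    """Project trajectories and fixed points onto the top principal components."""
    n_batch, n_time, n_states = state_traj.shape
    flat = state_traj.reshape(-1, n_states)

    # Assumption: the PCA basis is fit on the trajectory states.
    mean = flat.mean(axis=0)
    _, _, vt = np.linalg.svd(flat - mean, full_matrices=False)
    basis = vt[:n_components].T              # [n_states x n_components]

    traj_pc = ((flat - mean) @ basis).reshape(n_batch, n_time, n_components)
    fps_pc = (xstar - mean) @ basis          # same centering for fixed points
    return traj_pc, fps_pc


# Usage with random stand-in data: 2 components for a 2D plot, 3 for 3D.
rng = np.random.default_rng(0)
traj = rng.standard_normal((3, 20, 5))
fps = rng.standard_normal((4, 5))
traj_pc, fps_pc = pca_project(traj, fps, n_components=2)
```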
- src.canns.analyzer.slow_points.save_checkpoint(model, filepath)
Save model parameters to a checkpoint file using BrainPy checkpointing.
- Parameters:
model (brainpy.DynamicalSystem) – BrainPy model to save.
filepath (str) – Path to save the checkpoint file.
Example
>>> from canns.analyzer.slow_points import save_checkpoint
>>> save_checkpoint(rnn, "my_model.msgpack")
Saved checkpoint to: my_model.msgpack