src.canns.analyzer.slow_points

Fixed point finder for BrainPy RNN models.

This module provides tools for identifying and analyzing fixed points in recurrent neural networks using JAX/BrainPy.

Classes

FixedPointFinder

Find and analyze fixed points in RNN dynamics.

FixedPoints

Container for storing and manipulating fixed points.

Functions

load_checkpoint(model, filepath)

Load model parameters from a checkpoint file using BrainPy checkpointing.

plot_fixed_points_2d(fixed_points, state_traj[, ...])

Plot fixed points and trajectories in 2D using PCA.

plot_fixed_points_3d(fixed_points, state_traj[, ...])

Plot fixed points and trajectories in 3D using PCA.

save_checkpoint(model, filepath)

Save model parameters to a checkpoint file using BrainPy checkpointing.

Package Contents

class src.canns.analyzer.slow_points.FixedPointFinder(rnn_model, method='joint', max_iters=5000, tol_q=1e-12, tol_dq=1e-20, lr_init=1.0, lr_factor=0.95, lr_patience=5, lr_cooldown=0, do_compute_jacobians=True, do_decompose_jacobians=True, tol_unique=0.001, do_exclude_distance_outliers=True, outlier_distance_scale=10.0, do_rerun_q_outliers=False, outlier_q_scale=10.0, max_n_unique=np.inf, final_q_threshold=1e-08, dtype='float32', verbose=True, super_verbose=False, n_iters_per_print_update=100)

Find and analyze fixed points in RNN dynamics.

This class implements an optimization-based approach to finding fixed points in recurrent neural networks. It uses gradient descent to minimize the objective q = 0.5 * ||x - F(x, u)||^2, where F is the RNN transition function.

The implementation is compatible with BrainPy RNN models and uses JAX for automatic differentiation and optimization.
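The objective above can be illustrated with a self-contained NumPy sketch. It assumes a hypothetical vanilla RNN F(x, u) = tanh(W x + B u), writes the gradient of q out by hand instead of using JAX autodiff, and runs plain gradient descent with a fixed learning rate, so it leaves out the learning-rate schedule, batching and outlier handling that FixedPointFinder provides. The names rnn_step, q_and_grad and find_fixed_point are illustrative only.

```python
import numpy as np


def rnn_step(x, u, W, B):
    """Hypothetical vanilla RNN transition F(x, u) = tanh(W x + B u)."""
    return np.tanh(W @ x + B @ u)


def q_and_grad(x, u, W, B):
    """Return q = 0.5 * ||x - F(x, u)||^2 and dq/dx for the toy RNN."""
    fx = rnn_step(x, u, W, B)
    r = x - fx                                  # residual x - F(x, u)
    J = (1.0 - fx**2)[:, None] * W              # dF/dx = diag(1 - tanh^2) W
    grad = (np.eye(x.size) - J).T @ r           # dq/dx = (I - J)^T r
    return 0.5 * float(r @ r), grad


def find_fixed_point(x0, u, W, B, lr=0.5, max_iters=5000, tol_q=1e-12):
    """Plain fixed-step gradient descent on q from a single initial state."""
    x = x0.copy()
    for n_iters in range(1, max_iters + 1):
        q, grad = q_and_grad(x, u, W, B)
        if q < tol_q:                           # converged: q below tolerance
            break
        x = x - lr * grad
    return x, q, n_iters


rng = np.random.default_rng(0)
n_states, n_inputs = 8, 2
W = 0.3 * rng.standard_normal((n_states, n_states)) / np.sqrt(n_states)  # weak recurrence
B = rng.standard_normal((n_states, n_inputs))
u = np.zeros(n_inputs)                          # constant input
xstar, qstar, n_iters = find_fixed_point(rng.standard_normal(n_states), u, W, B)
print(f"q = {qstar:.2e} after {n_iters} iterations")
```

With u = 0 the origin is a fixed point (F(0, 0) = tanh(0) = 0), and because W is scaled down the recurrence is contracting, so gradient descent should converge to it.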

Initialize the FixedPointFinder.

Parameters:
  • rnn_model (brainpy.DynamicalSystem) – A BrainPy RNN model with a __call__(inputs, hidden) signature.

  • method (str) – Optimization method (‘joint’ or ‘sequential’).

  • max_iters (int) – Maximum optimization iterations.

  • tol_q (float) – Tolerance for q value convergence.

  • tol_dq (float) – Tolerance for change in q value.

  • lr_init (float) – Initial learning rate.

  • lr_factor (float) – Learning rate reduction factor.

  • lr_patience (int) – Patience for learning rate scheduler.

  • lr_cooldown (int) – Cooldown for learning rate scheduler.

  • do_compute_jacobians (bool) – Whether to compute Jacobians.

  • do_decompose_jacobians (bool) – Whether to eigendecompose Jacobians.

  • tol_unique (float) – Tolerance for identifying unique fixed points.

  • do_exclude_distance_outliers (bool) – Whether to exclude distance outliers.

  • outlier_distance_scale (float) – Scale for distance outlier detection.

  • do_rerun_q_outliers (bool) – Whether to rerun optimization on q outliers.

  • outlier_q_scale (float) – Scale for q outlier detection.

  • max_n_unique (float) – Maximum number of unique fixed points to keep.

  • dtype (str) – Data type for computations.

  • verbose (bool) – Print high-level status updates.

  • super_verbose (bool) – Print per-iteration updates.

  • n_iters_per_print_update (int) – Print frequency during optimization.
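The lr_* parameters suggest a reduce-on-plateau schedule. The sketch below assumes that reading: the learning rate starts at lr_init, is multiplied by lr_factor once q has failed to improve for more than lr_patience consecutive steps, and further reductions are then suspended for lr_cooldown steps. PlateauLR is a hypothetical class, not part of canns, and the exact counting rules inside FixedPointFinder may differ.

```python
class PlateauLR:
    """Hypothetical reduce-on-plateau schedule (assumed semantics for
    lr_init, lr_factor, lr_patience and lr_cooldown; not the canns source)."""

    def __init__(self, lr_init=1.0, lr_factor=0.95, lr_patience=5, lr_cooldown=0):
        self.lr = lr_init
        self.factor = lr_factor
        self.patience = lr_patience
        self.cooldown = lr_cooldown
        self.best_q = float("inf")
        self.bad_steps = 0        # consecutive steps without a new best q
        self.cooldown_left = 0    # remaining steps in which no reduction is allowed

    def step(self, q):
        """Record the latest q and return the learning rate for the next step."""
        if q < self.best_q:
            self.best_q = q
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        if self.cooldown_left > 0:
            self.cooldown_left -= 1
            self.bad_steps = 0    # steps during cooldown do not count toward patience
        elif self.bad_steps > self.patience:
            self.lr *= self.factor
            self.bad_steps = 0
            self.cooldown_left = self.cooldown
        return self.lr


sched = PlateauLR(lr_patience=2)
print([sched.step(q) for q in [1.0, 0.5, 0.5, 0.5, 0.5]])
```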

find_fixed_points(state_traj, inputs, n_inits=1024, noise_scale=0.0, valid_bxt=None, cond_ids=None)

Find fixed points from sampled RNN states.

Parameters:
  • state_traj (numpy.ndarray) – [n_batch x n_time x n_states] trajectory of RNN states.

  • inputs (numpy.ndarray) – [1 x n_inputs] or [n_inits x n_inputs] constant inputs.

  • n_inits (int) – Number of initial states to sample.

  • noise_scale (float) – Std dev of Gaussian noise added to sampled states.

  • valid_bxt (numpy.ndarray | None) – [n_batch x n_time] boolean mask for valid samples.

  • cond_ids (numpy.ndarray | None) – [n_inits,] condition IDs for each initialization.

Returns:

  • unique_fps – FixedPoints object containing the unique fixed points.

  • all_fps – FixedPoints object containing all fixed points before filtering.

Return type:

tuple[FixedPoints, FixedPoints]
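How the initial states might be drawn from state_traj can be sketched as follows. This assumes uniform sampling with replacement over the (batch, time) pairs allowed by valid_bxt, isotropic Gaussian noise with standard deviation noise_scale, and broadcasting of [1 x n_inputs] inputs to one row per initialization. sample_initial_states and broadcast_inputs are hypothetical helpers, not the method's internals.

```python
import numpy as np


def sample_initial_states(state_traj, n_inits, noise_scale=0.0, valid_bxt=None, rng=None):
    """Draw n_inits initial states from a [n_batch x n_time x n_states] array.

    Hypothetical helper: samples (batch, time) pairs uniformly with replacement
    from those marked True in valid_bxt, then adds isotropic Gaussian noise
    with standard deviation noise_scale.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_batch, n_time, _ = state_traj.shape
    if valid_bxt is None:
        valid_bxt = np.ones((n_batch, n_time), dtype=bool)
    b_idx, t_idx = np.nonzero(valid_bxt)               # all valid (batch, time) pairs
    pick = rng.integers(0, len(b_idx), size=n_inits)   # sampled with replacement
    x_init = state_traj[b_idx[pick], t_idx[pick]]      # [n_inits x n_states]
    if noise_scale > 0:
        x_init = x_init + noise_scale * rng.standard_normal(x_init.shape)
    return x_init


def broadcast_inputs(inputs, n_inits):
    """Expand [1 x n_inputs] inputs to [n_inits x n_inputs]; pass others through."""
    inputs = np.asarray(inputs)
    if inputs.shape[0] == 1:
        return np.repeat(inputs, n_inits, axis=0)
    if inputs.shape[0] != n_inits:
        raise ValueError("inputs must have 1 or n_inits rows")
    return inputs


# Usage with random stand-in data (shapes only; not a trained RNN).
rng = np.random.default_rng(0)
state_traj = rng.standard_normal((4, 50, 16))
valid = np.zeros((4, 50), dtype=bool)
valid[:, 10:] = True                                   # e.g. skip initial transients
x_init = sample_initial_states(state_traj, 1024, noise_scale=0.01, valid_bxt=valid, rng=rng)
u_init = broadcast_inputs(np.zeros((1, 3)), 1024)
```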

do_compute_jacobians = True
do_decompose_jacobians = True
do_exclude_distance_outliers = True
do_rerun_q_outliers = False
final_q_threshold
lr_cooldown = 0
lr_factor
lr_init
lr_patience = 5
max_iters = 5000
max_n_unique
method = 'joint'
n_iters_per_print_update = 100
outlier_distance_scale
outlier_q_scale
rng
rnn_model
super_verbose = False
tol_dq
tol_q
tol_unique
verbose = True
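A possible reading of do_exclude_distance_outliers and outlier_distance_scale is sketched below. It assumes that distances are measured from the centroid of the initial states and scaled by the mean distance of those states to the same centroid, so outlier_distance_scale acts as a unitless multiple of the "typical" spread. exclude_distance_outliers is a hypothetical helper, and the threshold rule used by canns may differ.

```python
import numpy as np


def exclude_distance_outliers(xstar, x_init, outlier_distance_scale=10.0):
    """Return a boolean keep-mask over the rows of xstar (assumed rule).

    Distances are measured from the centroid of the initial states and
    normalized by the mean distance of the initial states to that centroid;
    fixed points beyond outlier_distance_scale times this typical distance
    are treated as outliers.
    """
    centroid = x_init.mean(axis=0)
    typical_dist = np.linalg.norm(x_init - centroid, axis=1).mean()
    fp_dist = np.linalg.norm(xstar - centroid, axis=1)
    return fp_dist / typical_dist < outlier_distance_scale


rng = np.random.default_rng(1)
x_init = rng.standard_normal((100, 5))
xstar = np.stack([np.zeros(5), np.full(5, 100.0)])   # one plausible, one runaway point
keep = exclude_distance_outliers(xstar, x_init)
print(keep)
```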
class src.canns.analyzer.slow_points.FixedPoints(xstar=None, F_xstar=None, x_init=None, inputs=None, qstar=None, dq=None, n_iters=None, J_xstar=None, dFdu=None, eigval_J_xstar=None, eigvec_J_xstar=None, is_stable=None, cond_id=None, tol_unique=0.001, dtype=np.float32)

Container for storing and manipulating fixed points.

This class stores fixed points found by the FixedPointFinder algorithm, along with associated metadata such as Jacobians, eigenvalues, and stability flags.

xstar

[n x n_states] array of fixed point states.

F_xstar

[n x n_states] array of states after one RNN step from xstar.

x_init

[n x n_states] array of initial states used for optimization.

inputs

[n x n_inputs] array of constant inputs during optimization.

qstar

[n,] array of final q values (optimization objective).

dq

[n,] array of change in q at the last optimization step.

n_iters

[n,] array of iteration counts for each optimization.

J_xstar

[n x n_states x n_states] array of Jacobians dF/dx at fixed points.

dFdu

[n x n_states x n_inputs] array of Jacobians dF/du at fixed points.

eigval_J_xstar

[n x n_states] complex array of eigenvalues.

eigvec_J_xstar

[n x n_states x n_states] complex array of eigenvectors.

is_stable

[n,] bool array indicating stability (max |eigenvalue| < 1).

cond_id

[n,] array of condition IDs (optional).

tol_unique

Tolerance for identifying unique fixed points.

dtype

NumPy dtype for data storage.

Initialize a FixedPoints object.

Parameters:
  • xstar (numpy.ndarray | None) – Fixed point states [n x n_states].

  • F_xstar (numpy.ndarray | None) – States after one RNN step [n x n_states].

  • x_init (numpy.ndarray | None) – Initial states [n x n_states].

  • inputs (numpy.ndarray | None) – Constant inputs [n x n_inputs].

  • qstar (numpy.ndarray | None) – Final q values [n,].

  • dq (numpy.ndarray | None) – Change in q at last step [n,].

  • n_iters (numpy.ndarray | None) – Iteration counts [n,].

  • J_xstar (numpy.ndarray | None) – Jacobians dF/dx [n x n_states x n_states].

  • dFdu (numpy.ndarray | None) – Jacobians dF/du [n x n_states x n_inputs].

  • eigval_J_xstar (numpy.ndarray | None) – Eigenvalues [n x n_states] (complex).

  • eigvec_J_xstar (numpy.ndarray | None) – Eigenvectors [n x n_states x n_states] (complex).

  • is_stable (numpy.ndarray | None) – Stability flags [n,].

  • cond_id (numpy.ndarray | None) – Condition IDs [n,].

  • tol_unique (float) – Tolerance for uniqueness detection.

  • dtype – NumPy data type for storage.

__getitem__(idx)

Index into the fixed points.

Parameters:

idx – Integer index, slice, or array of indices.

Returns:

A new FixedPoints object containing the indexed subset.

__len__()

Return the number of fixed points.
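Indexing works naturally when every per-fixed-point array shares a leading axis of length n. The hypothetical MiniFixedPoints dataclass below applies an index to each non-None field in turn. It keeps only four fields and turns an integer index into a length-1 selection; both are choices of this sketch, not documented behavior of FixedPoints.

```python
from dataclasses import dataclass, fields
from typing import Optional

import numpy as np


@dataclass
class MiniFixedPoints:
    """Hypothetical stripped-down container; every field has a leading axis n."""
    xstar: Optional[np.ndarray] = None
    inputs: Optional[np.ndarray] = None
    qstar: Optional[np.ndarray] = None
    is_stable: Optional[np.ndarray] = None

    def __len__(self):
        return 0 if self.xstar is None else self.xstar.shape[0]

    def __getitem__(self, idx):
        # Wrap integer indices so the result is still a batch (of length 1).
        if isinstance(idx, (int, np.integer)):
            idx = [idx]
        subset = {}
        for f in fields(self):
            value = getattr(self, f.name)
            subset[f.name] = None if value is None else value[idx]
        return MiniFixedPoints(**subset)


fps = MiniFixedPoints(
    xstar=np.arange(12.0).reshape(4, 3),
    qstar=np.array([1e-9, 1e-3, 1e-12, 1e-6]),
)
good = fps[fps.qstar < 1e-8]     # boolean mask keeps rows 0 and 2
print(len(fps), len(good), len(fps[0]), len(fps[1:3]))   # 4 2 1 2
```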

decompose_jacobians(verbose=False)

Compute eigendecomposition of Jacobians and determine stability.

Computes eigenvalues and eigenvectors of self.J_xstar and marks each fixed point as stable when max |eigenvalue| < 1, the stability criterion for a discrete-time map.

Parameters:

verbose (bool) – Whether to print status messages.
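The eigendecomposition and stability test can be sketched with NumPy's batched np.linalg.eig, assuming J_xstar is stored as an [n x n_states x n_states] array as described above. decompose_jacobians here is a standalone function for illustration, not the method itself.

```python
import numpy as np


def decompose_jacobians(J_xstar):
    """Eigendecompose a stack of Jacobians with shape [n x n_states x n_states].

    Returns eigenvalues [n x n_states], eigenvectors [n x n_states x n_states]
    (column k pairs with eigenvalue k) and a stability mask based on the
    discrete-time criterion max |eigenvalue| < 1.
    """
    eigval, eigvec = np.linalg.eig(J_xstar)          # batched over the first axis
    is_stable = np.abs(eigval).max(axis=-1) < 1.0
    return eigval, eigvec, is_stable


J = np.stack([
    np.diag([0.5, -0.9]),                  # contracting in both directions
    np.array([[0.0, -1.2], [1.2, 0.0]]),   # rotation with gain 1.2 (eigenvalues +/-1.2i)
    np.diag([1.1, 0.2]),                   # one expanding direction
])
eigval, eigvec, is_stable = decompose_jacobians(J)
print(is_stable)
```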

get_unique()

Identify and return unique fixed points.

Uniqueness is determined by Euclidean distance in the concatenated (xstar, inputs) space. Among duplicates, the one with the lowest qstar is kept.

Returns:

A new FixedPoints object containing only unique fixed points.
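One way to implement the rule described above is a greedy pass in order of increasing qstar: each candidate is kept only if it is farther than tol_unique from every point kept so far, so the lowest-q member of each cluster survives. unique_indices is a hypothetical helper; the library's actual procedure may treat chains of near-duplicates differently.

```python
import numpy as np


def unique_indices(xstar, inputs, qstar, tol_unique=1e-3):
    """Greedy de-duplication in the concatenated (xstar, inputs) space.

    Candidates are visited in order of increasing qstar and kept only if they
    are farther than tol_unique (Euclidean) from every point kept so far.
    Returns the kept indices in ascending order.
    """
    z = np.concatenate([xstar, inputs], axis=1)
    kept = []
    for i in np.argsort(qstar):
        if all(np.linalg.norm(z[i] - z[j]) > tol_unique for j in kept):
            kept.append(int(i))
    return np.array(sorted(kept), dtype=int)


xstar = np.array([[0.0, 0.0], [0.0005, 0.0], [1.0, 1.0], [1.0, 1.0002]])
inputs = np.zeros((4, 1))
qstar = np.array([1e-6, 1e-9, 1e-7, 1e-8])
print(unique_indices(xstar, inputs, qstar))   # [1 3]
```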

print_summary()

Print a summary of the fixed points.

F_xstar = None
J_xstar = None
cond_id = None
dFdu = None
dq = None
dtype
eigval_J_xstar = None
eigvec_J_xstar = None
property has_decomposed_jacobians: bool

Check if Jacobians have been decomposed.

inputs = None
is_stable = None
n_iters = None
qstar = None
tol_unique
x_init = None
xstar = None
src.canns.analyzer.slow_points.load_checkpoint(model, filepath)

Load model parameters from a checkpoint file using BrainPy checkpointing.

Parameters:
  • model (brainpy.DynamicalSystem) – BrainPy model to load parameters into.

  • filepath (str) – Path to the checkpoint file.

Returns:

True if checkpoint was loaded successfully, False otherwise.

Return type:

bool

Example

>>> from canns.analyzer.slow_points import load_checkpoint
>>> if load_checkpoint(rnn, "my_model.msgpack"):
...     print("Loaded successfully")
... else:
...     print("No checkpoint found")
Loaded checkpoint from: my_model.msgpack
Loaded successfully
src.canns.analyzer.slow_points.plot_fixed_points_2d(fixed_points, state_traj, config=None, plot_batch_idx=None, plot_start_time=0)

Plot fixed points and trajectories in 2D using PCA.

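Parameters:
  • fixed_points (FixedPoints) – Fixed points to plot.

  • state_traj (numpy.ndarray) – [n_batch x n_time x n_states] RNN state trajectories.

  • config (PlotConfig | None) – Plot configuration, such as title, figsize and save_path.

  • plot_batch_idx – Indices of the trajectories in state_traj to plot.

  • plot_start_time (int) – Time step from which trajectories are plotted.

Returns: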

matplotlib Figure object.

Return type:

matplotlib.figure.Figure

Example

>>> from canns.analyzer.slow_points import plot_fixed_points_2d, FixedPoints
>>> from canns.analyzer.plotting import PlotConfig
>>> config = PlotConfig(
...     title="Fixed Points Analysis",
...     figsize=(10, 8),
...     save_path="fps_2d.png"
... )
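>>> # unique_fps: FixedPoints returned by FixedPointFinder.find_fixed_points
>>> # hiddens: [n_batch x n_time x n_states] RNN state trajectories
>>> fig = plot_fixed_points_2d(unique_fps, hiddens, config=config)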
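The PCA step shared by plot_fixed_points_2d and plot_fixed_points_3d can be sketched as follows, assuming PCA is fit on all trajectory states and the fixed points are projected onto the same components. pca_project is hypothetical and the matplotlib drawing is omitted; n_components=3 gives the coordinates a 3D plot would use.

```python
import numpy as np


def pca_project(state_traj, xstar, n_components=2):
    """Fit PCA on all trajectory states and project trajectories and fixed
    points into the same n_components-dimensional space (assumed approach)."""
    n_batch, n_time, n_states = state_traj.shape
    X = state_traj.reshape(-1, n_states)
    mean = X.mean(axis=0)
    # Rows of Vt are principal axes, ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
    components = Vt[:n_components]                                   # [k x n_states]
    traj_pc = ((X - mean) @ components.T).reshape(n_batch, n_time, n_components)
    fps_pc = (xstar - mean) @ components.T                           # [n_fps x k]
    return traj_pc, fps_pc


rng = np.random.default_rng(2)
state_traj = rng.standard_normal((3, 40, 10))
xstar = rng.standard_normal((5, 10))
traj_pc, fps_pc = pca_project(state_traj, xstar)
print(traj_pc.shape, fps_pc.shape)   # (3, 40, 2) (5, 2)
```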
src.canns.analyzer.slow_points.plot_fixed_points_3d(fixed_points, state_traj, config=None, plot_batch_idx=None, plot_start_time=0)

Plot fixed points and trajectories in 3D using PCA.

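Parameters:
  • fixed_points (FixedPoints) – Fixed points to plot.

  • state_traj (numpy.ndarray) – [n_batch x n_time x n_states] RNN state trajectories.

  • config (PlotConfig | None) – Plot configuration, such as title, figsize and save_path.

  • plot_batch_idx – Indices of the trajectories in state_traj to plot.

  • plot_start_time (int) – Time step from which trajectories are plotted.

Returns: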

matplotlib Figure object.

Return type:

matplotlib.figure.Figure

Example

>>> from canns.analyzer.slow_points import plot_fixed_points_3d, FixedPoints
>>> from canns.analyzer.plotting import PlotConfig
>>> config = PlotConfig(
...     title="Fixed Points 3D",
...     figsize=(12, 10),
...     save_path="fps_3d.png"
... )
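>>> # unique_fps: FixedPoints returned by FixedPointFinder.find_fixed_points
>>> # hiddens: [n_batch x n_time x n_states] RNN state trajectories
>>> fig = plot_fixed_points_3d(unique_fps, hiddens, config=config)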
src.canns.analyzer.slow_points.save_checkpoint(model, filepath)

Save model parameters to a checkpoint file using BrainPy checkpointing.

Parameters:
  • model (brainpy.DynamicalSystem) – BrainPy model to save.

  • filepath (str) – Path to save the checkpoint file.

Example

>>> from canns.analyzer.slow_points import save_checkpoint
>>> save_checkpoint(rnn, "my_model.msgpack")
Saved checkpoint to: my_model.msgpack