Source code for src.canns.analyzer.slow_points.finder

"""Fixed point finder for BrainPy RNN models."""

import time

import brainpy as bp
import brainpy.math as bm
import jax
import jax.numpy as jnp
import numpy as np

from .fixed_points import FixedPoints


[docs] class FixedPointFinder: """Find and analyze fixed points in RNN dynamics. This class implements an optimization-based approach to finding fixed points in recurrent neural networks. It uses gradient descent to minimize the objective q = 0.5 * ||x - F(x, u)||^2, where F is the RNN transition function. The implementation is compatible with BrainPy RNN models and uses JAX for automatic differentiation and optimization. """ def __init__( self, rnn_model: bp.DynamicalSystem, method: str = "joint", max_iters: int = 5000, tol_q: float = 1e-12, tol_dq: float = 1e-20, lr_init: float = 1.0, lr_factor: float = 0.95, lr_patience: int = 5, lr_cooldown: int = 0, do_compute_jacobians: bool = True, do_decompose_jacobians: bool = True, tol_unique: float = 1e-3, do_exclude_distance_outliers: bool = True, outlier_distance_scale: float = 10.0, do_rerun_q_outliers: bool = False, outlier_q_scale: float = 10.0, max_n_unique: float = np.inf, final_q_threshold: float = 1e-8, dtype: str = "float32", verbose: bool = True, super_verbose: bool = False, n_iters_per_print_update: int = 100, ): """Initialize the FixedPointFinder. Args: rnn_model: A BrainPy RNN model with __call__(inputs, hidden) signature. method: Optimization method ('joint' or 'sequential'). max_iters: Maximum optimization iterations. tol_q: Tolerance for q value convergence. tol_dq: Tolerance for change in q value. lr_init: Initial learning rate. lr_factor: Learning rate reduction factor. lr_patience: Patience for learning rate scheduler. lr_cooldown: Cooldown for learning rate scheduler. do_compute_jacobians: Whether to compute Jacobians. do_decompose_jacobians: Whether to eigendecompose Jacobians. tol_unique: Tolerance for identifying unique fixed points. do_exclude_distance_outliers: Whether to exclude distance outliers. outlier_distance_scale: Scale for distance outlier detection. do_rerun_q_outliers: Whether to rerun optimization on q outliers. outlier_q_scale: Scale for q outlier detection. max_n_unique: Maximum number of unique fixed points to keep. dtype: Data type for computations. verbose: Print high-level status updates. super_verbose: Print per-iteration updates. n_iters_per_print_update: Print frequency during optimization. """
[docs] self.rnn_model = rnn_model
[docs] self.method = method
[docs] self.max_iters = int(max_iters)
[docs] self.tol_q = float(tol_q)
[docs] self.tol_dq = float(tol_dq)
[docs] self.lr_init = float(lr_init)
[docs] self.lr_factor = float(lr_factor)
[docs] self.lr_patience = int(lr_patience)
[docs] self.lr_cooldown = int(lr_cooldown)
[docs] self.do_compute_jacobians = bool(do_compute_jacobians)
[docs] self.do_decompose_jacobians = bool(do_decompose_jacobians)
[docs] self.tol_unique = float(tol_unique)
[docs] self.do_exclude_distance_outliers = bool(do_exclude_distance_outliers)
[docs] self.outlier_distance_scale = float(outlier_distance_scale)
[docs] self.do_rerun_q_outliers = bool(do_rerun_q_outliers)
[docs] self.outlier_q_scale = float(outlier_q_scale)
[docs] self.max_n_unique = max_n_unique
[docs] self.final_q_threshold = float(final_q_threshold)
[docs] self.verbose = bool(verbose)
[docs] self.super_verbose = bool(super_verbose)
[docs] self.n_iters_per_print_update = int(n_iters_per_print_update)
# Random number generator
[docs] self.rng = np.random.RandomState(0)
# Data type if dtype == "float32": self.np_dtype = np.float32 self.jax_dtype = jnp.float32 elif dtype == "float64": self.np_dtype = np.float64 self.jax_dtype = jnp.float64 else: raise ValueError(f"Unsupported dtype: {dtype}")
[docs] def find_fixed_points( self, state_traj: np.ndarray, inputs: np.ndarray, n_inits: int = 1024, noise_scale: float = 0.0, valid_bxt: np.ndarray | None = None, cond_ids: np.ndarray | None = None, ) -> tuple[FixedPoints, FixedPoints]: """Find fixed points from sampled RNN states. Args: state_traj: [n_batch x n_time x n_states] trajectory of RNN states. inputs: [1 x n_inputs] or [n_inits x n_inputs] constant inputs. n_inits: Number of initial states to sample. noise_scale: Std dev of Gaussian noise added to sampled states. valid_bxt: [n_batch x n_time] boolean mask for valid samples. cond_ids: [n_inits,] condition IDs for each initialization. Returns: unique_fps: FixedPoints object with unique fixed points. all_fps: FixedPoints object with all fixed points before filtering. """ self._print_if_verbose(f"\nSearching for fixed points from {n_inits} initial states.\n") # Sample initial states initial_states = self._sample_states(state_traj, n_inits, valid_bxt, noise_scale) # Prepare inputs if inputs.shape[0] == 1: inputs_nxd = np.tile(inputs, [n_inits, 1]) elif inputs.shape[0] == n_inits: inputs_nxd = inputs else: raise ValueError( f"Incompatible inputs shape: {inputs.shape}. " f"Expected [1, n_inputs] or [{n_inits}, n_inputs]." ) # Run optimization if self.method == "joint": # Warn if n_inits is large for joint optimization LARGE_N_INITS_THRESHOLD = 1000 if n_inits > LARGE_N_INITS_THRESHOLD: import warnings warnings.warn( f"Joint optimization with n_inits={n_inits} may be inefficient and use excessive memory. " f"Consider using sequential optimization or reducing n_inits.", stacklevel=2, ) all_fps = self._run_joint_optimization(initial_states, inputs_nxd, cond_ids) elif self.method == "sequential": all_fps = self._run_sequential_optimizations(initial_states, inputs_nxd, cond_ids) else: raise ValueError(f"Unsupported method: {self.method}. Must be 'joint' or 'sequential'.") # Filter unique fixed points unique_fps = all_fps.get_unique() self._print_if_verbose(f"\tIdentified {unique_fps.n} unique fixed points.") # Exclude distance outliers if self.do_exclude_distance_outliers and unique_fps.n > 0: unique_fps = self._exclude_distance_outliers(unique_fps, initial_states) # Rerun q outliers if self.do_rerun_q_outliers and unique_fps.n > 0: unique_fps = self._run_additional_iterations_on_outliers(unique_fps) unique_fps = unique_fps.get_unique() # Limit number of unique fixed points if unique_fps.n > self.max_n_unique: self._print_if_verbose( f"\tSelecting top {int(self.max_n_unique)} unique fixed points by qstar." ) # Sort fixed points by qstar (ascending = better convergence) idx_sorted = np.argsort(unique_fps.qstar) idx_keep = idx_sorted[: int(self.max_n_unique)] unique_fps = unique_fps[idx_keep] # Compute Jacobians if self.do_compute_jacobians and unique_fps.n > 0: self._print_if_verbose( f"\tComputing recurrent Jacobian at {unique_fps.n} unique fixed points." ) unique_fps.J_xstar = self._compute_recurrent_jacobians(unique_fps) self._print_if_verbose( f"\tComputing input Jacobian at {unique_fps.n} unique fixed points." ) unique_fps.dFdu = self._compute_input_jacobians(unique_fps) if self.do_decompose_jacobians: unique_fps.decompose_jacobians(verbose=self.verbose) # Set the conditions for final filtering if self.final_q_threshold > 0 and unique_fps.n > 0: self._print_if_verbose( f"\tApplying final q-value filter (q < {self.final_q_threshold:.1e})..." ) n_before_filter = unique_fps.n idx_keep = np.where(unique_fps.qstar < self.final_q_threshold)[0] unique_fps = unique_fps[idx_keep] n_after_filter = unique_fps.n n_discarded = n_before_filter - n_after_filter if self.verbose and n_discarded > 0: self._print_if_verbose(f"\t\tExcluded {n_discarded} low-quality fixed points.") self._print_if_verbose(f"\t\t{n_after_filter} high-quality fixed points remain.") self._print_if_verbose("\tFixed point finding complete.\n") return unique_fps, all_fps
def _sample_states( self, state_traj: np.ndarray, n_inits: int, valid_bxt: np.ndarray | None, noise_scale: float, ) -> np.ndarray: """Sample initial states from trajectory. Args: state_traj: [n_batch x n_time x n_states] state trajectory. n_inits: Number of samples to draw. valid_bxt: [n_batch x n_time] boolean mask. noise_scale: Std dev of Gaussian noise. Returns: [n_inits x n_states] sampled initial states. """ n_batch, n_time, n_states = state_traj.shape # Create valid mask if valid_bxt is None: valid_bxt = np.ones((n_batch, n_time), dtype=bool) else: assert valid_bxt.shape == (n_batch, n_time), ( f"valid_bxt shape {valid_bxt.shape} does not match expected ({n_batch}, {n_time})" ) # Sample trial and time indices trial_idx, time_idx = np.nonzero(valid_bxt) max_sample_index = len(trial_idx) # Sample without replacement if possible, otherwise allow duplicates if n_inits <= max_sample_index: sample_indices = self.rng.choice(max_sample_index, size=n_inits, replace=False) else: # If we need more samples than available, allow duplicates sample_indices = self.rng.randint(max_sample_index, size=n_inits) # Draw samples states = np.zeros([n_inits, n_states], dtype=self.np_dtype) for i, idx in enumerate(sample_indices): t_idx = trial_idx[idx] time_idx_i = time_idx[idx] states[i, :] = state_traj[t_idx, time_idx_i, :] # Add noise if noise_scale > 0: states += noise_scale * self.rng.randn(n_inits, n_states).astype(self.np_dtype) assert not np.any(np.isnan(states)), "Detected NaNs in sampled states." return states def _run_joint_optimization( self, initial_states: np.ndarray, inputs: np.ndarray, cond_ids: np.ndarray | None, ) -> FixedPoints: """Run joint optimization over all initial states. Args: initial_states: [n x n_states] initial states. inputs: [n x n_inputs] constant inputs. cond_ids: [n,] condition IDs. Returns: FixedPoints object with optimization results. """ self._print_if_verbose("\tFinding fixed points via joint optimization.") n, n_states = initial_states.shape _, n_inputs = inputs.shape # Convert to JAX arrays x_init = jnp.array(initial_states, dtype=self.jax_dtype) u = jnp.array(inputs, dtype=self.jax_dtype) # Create optimization variables as BrainPy Variable x_state = bm.Variable(x_init) # Create optimizer optimizer = bp.optim.Adam(lr=self.lr_init) optimizer.register_train_vars({"x": x_state}) # Track learning rate manually (simplified scheduler) current_lr = self.lr_init lr_patience_counter = 0 lr_cooldown_counter = 0 best_q = float("inf") # Optimization loop iter_count = 0 q_prev = jnp.full(n, float("nan")) t_start = time.time() while True: iter_count += 1 # Get current x x_current = x_state.value # Compute F(x) F_x = self._compute_F(x_current, u) # Compute q = 0.5 * ||x - F(x)||^2 dx = x_current - F_x q = 0.5 * jnp.sum(dx**2, axis=1) q_mean = jnp.mean(q) dq = jnp.abs(q - q_prev) # Compute gradients def loss_fn(): x_opt = x_state.value F_x_opt = self._compute_F(x_opt, u) dx_opt = x_opt - F_x_opt return jnp.mean(0.5 * jnp.sum(dx_opt**2, axis=1)) grads_raw = bm.grad(loss_fn, grad_vars=x_state)() # Wrap gradients in dictionary for optimizer grads = {"x": grads_raw} # Update optimizer.update(grads) # Manual learning rate scheduling (simplified) if iter_count > 1: if q_mean < best_q: best_q = float(q_mean) lr_patience_counter = 0 else: if lr_cooldown_counter == 0: lr_patience_counter += 1 if lr_patience_counter >= self.lr_patience: current_lr *= self.lr_factor optimizer.lr.value = current_lr lr_patience_counter = 0 lr_cooldown_counter = self.lr_cooldown else: lr_cooldown_counter -= 1 # Print update if self.super_verbose and iter_count % self.n_iters_per_print_update == 0: self._print_iter_update( iter_count, t_start, np.array(q), np.array(dq), current_lr, ) # Check convergence if iter_count > 1 and np.all( np.logical_or( np.array(dq) < self.tol_dq * current_lr, np.array(q) < self.tol_q, ) ): self._print_if_verbose("\tOptimization complete to desired tolerance.") break if iter_count >= self.max_iters: self._print_if_verbose("\tMaximum iteration count reached. Terminating.") break q_prev = q # Final print if self.verbose: self._print_iter_update( iter_count, t_start, np.array(q), np.array(dq), current_lr, is_final=True, ) # Extract results xstar = np.array(x_state.value, dtype=self.np_dtype) F_xstar = np.array(F_x, dtype=self.np_dtype) qstar = np.array(q, dtype=self.np_dtype) dq_final = np.array(dq, dtype=self.np_dtype) n_iters = np.full(n, iter_count, dtype=np.int32) return FixedPoints( xstar=xstar, F_xstar=F_xstar, x_init=initial_states.astype(self.np_dtype), inputs=inputs.astype(self.np_dtype), qstar=qstar, dq=dq_final, n_iters=n_iters, cond_id=cond_ids, tol_unique=self.tol_unique, dtype=self.np_dtype, ) def _run_sequential_optimizations( self, initial_states: np.ndarray, inputs: np.ndarray, cond_ids: np.ndarray | None, ) -> FixedPoints: """Run sequential optimizations, one initial state at a time. Args: initial_states: [n x n_states] initial states. inputs: [n x n_inputs] constant inputs. cond_ids: [n,] condition IDs. Returns: FixedPoints object with concatenated results. """ self._print_if_verbose("\tFinding fixed points via sequential optimizations...") fps_list = [] n_inits = initial_states.shape[0] for i in range(n_inits): self._print_if_verbose(f"\n\tInitialization {i + 1} of {n_inits}:") cond_id_i = None if cond_ids is None else cond_ids[i : i + 1] fps_i = self._run_joint_optimization( initial_states[i : i + 1, :], inputs[i : i + 1, :], cond_id_i, ) fps_list.append(fps_i) # Concatenate results return self._concatenate_fps(fps_list) def _compute_F(self, x: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray: """Compute F(x, u) = next hidden state. Args: x: [n x n_states] current hidden states. u: [n x n_inputs] inputs. Returns: [n x n_states] next hidden states. """ # Assume the RNN model has signature: output, h_next = model(input, h) # We need to expand dims to add time dimension u_expanded = jnp.expand_dims(u, axis=1) # [n x 1 x n_inputs] # Call the model _, h_next = self.rnn_model(u_expanded, x) return h_next def _compute_recurrent_jacobians(self, fps: FixedPoints) -> np.ndarray: """Compute Jacobian dF/dx at fixed points. Args: fps: FixedPoints object. Returns: [n x n_states x n_states] Jacobian matrices. """ xstar = jnp.array(fps.xstar, dtype=self.jax_dtype) inputs_jax = jnp.array(fps.inputs, dtype=self.jax_dtype) def F_batched(x): """Compute F(x) for all x in batch.""" return self._compute_F(x, inputs_jax) # Use JAX vmap + jacrev for efficient batched Jacobian computation def jacobian_single(x_i, u_i): """Compute Jacobian for a single fixed point.""" def F_single(x): return self._compute_F(x[None, :], u_i[None, :])[0] return jax.jacrev(F_single)(x_i) jacobian_batched = jax.vmap(jacobian_single)(xstar, inputs_jax) return np.array(jacobian_batched, dtype=self.np_dtype) def _compute_input_jacobians(self, fps: FixedPoints) -> np.ndarray: """Compute Jacobian dF/du at fixed points. Args: fps: FixedPoints object. Returns: [n x n_states x n_inputs] Jacobian matrices. """ xstar = jnp.array(fps.xstar, dtype=self.jax_dtype) inputs_jax = jnp.array(fps.inputs, dtype=self.jax_dtype) def jacobian_single(x_i, u_i): """Compute input Jacobian for a single fixed point.""" def F_single_u(u): return self._compute_F(x_i[None, :], u[None, :])[0] return jax.jacrev(F_single_u)(u_i) jacobian_batched = jax.vmap(jacobian_single)(xstar, inputs_jax) return np.array(jacobian_batched, dtype=self.np_dtype) def _exclude_distance_outliers( self, fps: FixedPoints, initial_states: np.ndarray ) -> FixedPoints: """Exclude fixed points that are far from initial states. Args: fps: FixedPoints object. initial_states: [n x n_states] initial states. Returns: FixedPoints object with outliers removed. """ centroid = np.mean(initial_states, axis=0) init_dists = np.linalg.norm(initial_states - centroid, axis=1) avg_init_dist = np.mean(init_dists) + 1e-12 fps_dists = np.linalg.norm(fps.xstar - centroid, axis=1) scaled_fps_dists = fps_dists / avg_init_dist idx_keep = np.where(scaled_fps_dists < self.outlier_distance_scale)[0] n_excluded = fps.n - len(idx_keep) if self.verbose and n_excluded > 0: print(f"\t\tExcluded {n_excluded} distance outliers (of {fps.n}).") return fps[idx_keep] def _run_additional_iterations_on_outliers(self, fps: FixedPoints) -> FixedPoints: """Run additional optimization iterations on q outliers. Args: fps: FixedPoints object. Returns: FixedPoints object with improved outlier estimates. """ outlier_min_q = np.median(fps.qstar) * self.outlier_q_scale idx_outliers = np.where(fps.qstar > outlier_min_q)[0] if len(idx_outliers) == 0: return fps self._print_if_verbose( f"\n\tDetected {len(idx_outliers)} putative q outliers (q > {outlier_min_q:.2e})." ) self._print_if_verbose("\tPerforming sequential optimizations on outliers...") outlier_fps = fps[idx_outliers] improved_fps = self._run_sequential_optimizations( outlier_fps.xstar, outlier_fps.inputs, outlier_fps.cond_id, ) # Update iteration counts improved_fps.n_iters += outlier_fps.n_iters # Replace outliers in original fps fps.xstar[idx_outliers] = improved_fps.xstar fps.F_xstar[idx_outliers] = improved_fps.F_xstar fps.qstar[idx_outliers] = improved_fps.qstar fps.dq[idx_outliers] = improved_fps.dq fps.n_iters[idx_outliers] = improved_fps.n_iters return fps @staticmethod def _concatenate_fps(fps_list) -> FixedPoints: """Concatenate a list of FixedPoints objects. Args: fps_list: List of FixedPoints objects. Returns: Single concatenated FixedPoints object. """ if len(fps_list) == 0: return FixedPoints() def cat_attr(attr_name): vals = [getattr(fp, attr_name) for fp in fps_list] if all(v is None for v in vals): return None return np.concatenate([v for v in vals if v is not None], axis=0) first = fps_list[0] return FixedPoints( xstar=cat_attr("xstar"), F_xstar=cat_attr("F_xstar"), x_init=cat_attr("x_init"), inputs=cat_attr("inputs"), qstar=cat_attr("qstar"), dq=cat_attr("dq"), n_iters=cat_attr("n_iters"), J_xstar=cat_attr("J_xstar"), dFdu=cat_attr("dFdu"), eigval_J_xstar=cat_attr("eigval_J_xstar"), eigvec_J_xstar=cat_attr("eigvec_J_xstar"), is_stable=cat_attr("is_stable"), cond_id=cat_attr("cond_id"), tol_unique=first.tol_unique, dtype=first.dtype, ) def _print_if_verbose(self, msg: str): """Print message if verbose mode is enabled.""" if self.verbose: print(msg) @staticmethod def _print_iter_update( iter_count: int, t_start: float, q: np.ndarray, dq: np.ndarray, lr: float, is_final: bool = False, ): """Print optimization iteration update.""" t_elapsed = time.time() - t_start avg_iter_time = t_elapsed / iter_count if is_final: print(f"\t\t{iter_count} iters, ", end="") else: print(f"\tIter: {iter_count}, ", end="") if q.size == 1: print(f"q = {q[0]:.2e}, dq = {dq[0]:.2e}, ", end="") else: print( f"q = {np.mean(q):.2e} +/- {np.std(q):.2e}, " f"dq = {np.mean(dq):.2e} +/- {np.std(dq):.2e}, ", end="", ) print(f"lr = {lr:.2e}, avg iter time = {avg_iter_time:.2e} sec", end="") if is_final: print() # Newline else: print(".") # Continue line