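RNN Fixed-Point Analysis Manual: The FlipFlop Task as an Example

Goal: This document is a manual that explains in detail how to use the FixedPointFinder tool to analyze an RNN trained on the FlipFlop task.

Structure:

  1. Principles: what is a fixed point?

  2. Environment setup: import the libraries the script needs.

  3. Component definitions: the FlipFlopData class, the FlipFlopRNN class, and the train_flipflop_rnn function, introduced one by one.

  4. Core usage: a walkthrough of the FixedPointFinder workflow.

  5. Multi-configuration comparison: 2-bit, 3-bit, and 4-bit tasks.

  6. Summary.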

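1. Principles: What Is a Fixed Point?

A fixed point is a core concept in dynamical systems. An RNN can be viewed as a function h_{t+1} = F(h_t, u_t), where h is the hidden state and u is the input.

When the input u is held constant (for example u = 0 during the "memory" phase of the FlipFlop task, when no pulse arrives), the system evolves as h_{t+1} = F(h_t).

A fixed point x* is a state that satisfies x* = F(x*).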

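  • Stable fixed point: an attractor. If the RNN state h comes close enough to x*, it converges to x* and stays there.

  • Unstable fixed point: a repeller or a saddle point. A small perturbation away from x* grows along at least one direction, so the state leaves the neighborhood of x*.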
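
The two kinds of fixed point can be seen in a one-dimensional toy system. The sketch below is not part of flipflop_fixed_points.py; the map F(x) = tanh(2x) and the helper names F, dF_dx, and iterate are assumptions chosen only for illustration. Iterating the map from either side of zero lands on a stable fixed point, while x* = 0 is a fixed point that the iteration moves away from.

```python
import numpy as np

def F(x, w=2.0):
    """Toy 1-D update with zero input: x_{t+1} = tanh(w * x_t)."""
    return np.tanh(w * x)

def dF_dx(x, w=2.0):
    """Derivative of F; in one dimension it plays the role of the Jacobian."""
    return w * (1.0 - np.tanh(w * x) ** 2)

def iterate(x0, n_steps=200):
    """Apply F repeatedly, starting from x0."""
    x = x0
    for _ in range(n_steps):
        x = F(x)
    return x

x_pos = iterate(0.1)     # starts just to the right of 0
x_neg = iterate(-0.1)    # starts just to the left of 0

# Both end points satisfy x* = F(x*) up to numerical precision.
residual_pos = abs(F(x_pos) - x_pos)
residual_neg = abs(F(x_neg) - x_neg)

# |dF/dx| < 1 at a fixed point means stable, |dF/dx| > 1 means unstable.
slope_at_zero = dF_dx(0.0)    # equals w = 2.0, so x* = 0 is unstable
slope_at_pos = dF_dx(x_pos)   # below 1, so this fixed point is stable

print(x_pos, x_neg, slope_at_zero, slope_at_pos)
```

In the RNN case the scalar slope is replaced by the Jacobian matrix, and the same test is applied to the magnitudes of its eigenvalues.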

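Core idea: we train an RNN on the FlipFlop task. Once training succeeds, the RNN has learned to create one stable fixed point for every state it must "remember" (for example [+1, +1], [+1, -1], and so on). When the input is u = 0, the RNN state flows to one of these fixed points and stays there, which is what implements the memory.

The purpose of this tutorial is to use FixedPointFinder to locate all of the stable fixed points "hidden" inside the RNN.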

2. Imports

Import all libraries used by the flipflop_fixed_points.py script.

[ ]:
import brainpy as bp
import brainpy.math as bm
import jax
import jax.numpy as jnp
import numpy as np
import random
from canns.analyzer.plotting import PlotConfig
from canns.analyzer.slow_points import FixedPointFinder, save_checkpoint, load_checkpoint, plot_fixed_points_2d, plot_fixed_points_3d

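3. Component Definitions: Data, Model, and Training

This section gives the full definitions of the three core components of the flipflop_fixed_points.py script.

3.1 Component 1: the FlipFlopData class

This is the FlipFlopData class from flipflop_fixed_points.py. It draws random +1/-1 pulses on each channel, and the target on each channel is the sign of the most recent pulse on that channel.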

[2]:
class FlipFlopData:
    """Generator for flip-flop memory task data."""

    def __init__(self, n_bits=3, n_time=64, p=0.5, random_seed=0):
        """Initialize FlipFlopData generator.

        Args:
            n_bits: Number of memory channels.
            n_time: Number of timesteps per trial.
            p: Probability of input pulse at each timestep.
            random_seed: Random seed for reproducibility.
        """
        self.rng = np.random.RandomState(random_seed)
        self.n_time = n_time
        self.n_bits = n_bits
        self.p = p

    def generate_data(self, n_trials):
        """Generate flip-flop task data.

        Args:
            n_trials: Number of trials to generate.

        Returns:
            dict with 'inputs' and 'targets' arrays [n_trials x n_time x n_bits].
        """
        n_time = self.n_time
        n_bits = self.n_bits
        p = self.p

        # Generate unsigned input pulses
        unsigned_inputs = self.rng.binomial(1, p, [n_trials, n_time, n_bits])

        # Ensure every trial starts with a pulse
        unsigned_inputs[:, 0, :] = 1

        # Generate random signs {-1, +1}
        random_signs = 2 * self.rng.binomial(1, 0.5, [n_trials, n_time, n_bits]) - 1

        # Apply random signs
        inputs = unsigned_inputs * random_signs

        # Compute targets
        targets = np.zeros([n_trials, n_time, n_bits])
        for trial_idx in range(n_trials):
            for bit_idx in range(n_bits):
                input_seq = inputs[trial_idx, :, bit_idx]
                t_flip = np.where(input_seq != 0)[0]
                for flip_idx in range(len(t_flip)):
                    t_flip_i = t_flip[flip_idx]
                    targets[trial_idx, t_flip_i:, bit_idx] = inputs[
                        trial_idx, t_flip_i, bit_idx
                    ]

        return {
            "inputs": inputs.astype(np.float32),
            "targets": targets.astype(np.float32),
        }
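
The target rule used in generate_data can be checked by hand on a single channel. The helper below, flipflop_targets, is a hypothetical standalone function written for this illustration (it is not part of the script); it applies the same rule, holding the sign of the most recent nonzero pulse.

```python
import numpy as np

def flipflop_targets(inputs):
    """Target at each step = most recent nonzero pulse (0 before any pulse)."""
    targets = np.zeros_like(inputs, dtype=np.float32)
    last = 0.0
    for t, u in enumerate(inputs):
        if u != 0:
            last = u          # a new pulse overwrites the remembered value
        targets[t] = last     # otherwise the old value is held
    return targets

# One channel, eight timesteps: pulses at t = 0, 3, and 5.
pulses = np.array([+1, 0, 0, -1, 0, +1, 0, 0], dtype=np.float32)
targets = flipflop_targets(pulses)
print(targets)  # [ 1.  1.  1. -1. -1.  1.  1.  1.]
```

Because FlipFlopData forces a pulse at the first timestep of every trial, the targets it produces are never zero.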

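3.2 Component 2: the FlipFlopRNN class

This is the FlipFlopRNN class from the flipflop_fixed_points.py script.

Usage note: FixedPointFinder searches for states that satisfy x = F(x, u). To evaluate F(x, u) it calls rnn_model(inputs, hidden), passing inputs of shape [batch, 1, n_inputs] and hidden of shape [batch, n_hidden].

Your __call__ method must therefore handle the n_time == 1 case and return (outputs, h_next).

See the if n_time == 1: branch in the code below; it exists specifically to support FixedPointFinder.

Note: the class now inherits from bp.DynamicalSystem (rather than bp.nn.Module) and manages state with bm.Variable (rather than bp.ParamState), following the modern BrainPy API.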

[ ]:
class FlipFlopRNN(bp.DynamicalSystem):
    """RNN model for the flip-flop memory task."""

    def __init__(self, n_inputs, n_hidden, n_outputs, rnn_type="gru", seed=0):
        """Initialize FlipFlop RNN.

        Args:
            n_inputs: Number of input channels.
            n_hidden: Number of hidden units.
            n_outputs: Number of output channels.
            rnn_type: Type of RNN cell ('tanh', 'gru').
            seed: Random seed for weight initialization.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.n_outputs = n_outputs
        self.rnn_type = rnn_type.lower()

        # Initialize RNN cell parameters
        key = jax.random.PRNGKey(seed)
        k1, k2, k3, k4 = jax.random.split(key, 4)

        # Branch on the lower-cased self.rnn_type so that e.g. "GRU" is accepted
        if self.rnn_type == "tanh":
            # Simple tanh RNN
            self.w_ih = bm.Variable(
                jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hh = bm.Variable(
                jax.random.normal(k2, (n_hidden, n_hidden)) * 0.5
            )
            self.b_h = bm.Variable(jnp.zeros(n_hidden))
        elif self.rnn_type == "gru":
            # GRU cell
            self.w_ir = bm.Variable(
                jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hr = bm.Variable(
                jax.random.normal(k2, (n_hidden, n_hidden)) * 0.5
            )
            self.w_iz = bm.Variable(
                jax.random.normal(k3, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hz = bm.Variable(
                jax.random.normal(k4, (n_hidden, n_hidden)) * 0.5
            )
            k5, k6, k7, k8 = jax.random.split(k4, 4)
            self.w_in = bm.Variable(
                jax.random.normal(k5, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hn = bm.Variable(
                jax.random.normal(k6, (n_hidden, n_hidden)) * 0.5
            )
            self.b_r = bm.Variable(jnp.zeros(n_hidden))
            self.b_z = bm.Variable(jnp.zeros(n_hidden))
            self.b_n = bm.Variable(jnp.zeros(n_hidden))
        else:
            raise ValueError(f"Unsupported rnn_type: {rnn_type}")

        # Readout layer
        self.w_out = bm.Variable(
            jax.random.normal(k3, (n_hidden, n_outputs)) * 0.1
        )
        self.b_out = bm.Variable(jnp.zeros(n_outputs))

        # Initial hidden state
        self.h0 = bm.Variable(jnp.zeros(n_hidden))

    def step(self, x_t, h):
        """Single RNN step.

        Args:
            x_t: [batch_size x n_inputs] input at time t.
            h: [batch_size x n_hidden] hidden state.

        Returns:
            h_next: [batch_size x n_hidden] next hidden state.
        """
        if self.rnn_type == "tanh":
            # Simple tanh RNN step
            h_next = jnp.tanh(
                x_t @ self.w_ih.value + h @ self.w_hh.value + self.b_h.value
            )
        elif self.rnn_type == "gru":
            # GRU step
            r = jax.nn.sigmoid(
                x_t @ self.w_ir.value + h @ self.w_hr.value + self.b_r.value
            )
            z = jax.nn.sigmoid(
                x_t @ self.w_iz.value + h @ self.w_hz.value + self.b_z.value
            )
            n = jnp.tanh(
                x_t @ self.w_in.value + (r * h) @ self.w_hn.value + self.b_n.value
            )
            h_next = (1 - z) * n + z * h
        else:
            raise ValueError(f"Unknown rnn_type: {self.rnn_type}")

        return h_next

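    def __call__(self, inputs, hidden=None):
        """Forward pass through the RNN.

        Args:
            inputs: [batch_size x n_time x n_inputs] input sequence.
            hidden: Optional [batch_size x n_hidden] initial hidden state;
                defaults to self.h0 tiled over the batch.

        Returns:
            (outputs, h_next) when n_time == 1, with outputs
            [batch_size x 1 x n_outputs] and h_next [batch_size x n_hidden];
            otherwise (outputs, hiddens) covering every timestep, computed
            with jax.lax.scan.
        """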
        batch_size = inputs.shape[0]
        n_time = inputs.shape[1]

        # Initialize hidden state
        if hidden is None:
            h = jnp.tile(self.h0.value, (batch_size, 1))
        else:
            h = hidden

        # Single-step computation mode for the fixed-point finder
        if n_time == 1:
            x_t = inputs[:, 0, :]
            h_next = self.step(x_t, h)
            y = h_next @ self.w_out.value + self.b_out.value
            return y[:, None, :], h_next

        # Full sequence case
        def scan_fn(carry, x_t):
            """Single-step scan function"""
            h_prev = carry
            h_next = self.step(x_t, h_prev)
            y_t = h_next @ self.w_out.value + self.b_out.value
            return h_next, (y_t, h_next)

        # (batch, time, features) -> (time, batch, features)
        inputs_transposed = inputs.transpose(1, 0, 2)

        # Run the scan
        _, (outputs_seq, hiddens_seq) = jax.lax.scan(scan_fn, h, inputs_transposed)

        outputs = outputs_seq.transpose(1, 0, 2)
        hiddens = hiddens_seq.transpose(1, 0, 2)

        return outputs, hiddens
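
The single-step calling convention described in the usage note can be exercised without BrainPy. The sketch below uses a hypothetical stand-in model, toy_rnn, with hand-picked weights, and an assumed objective q = 0.5 * ||F(x, u) - x||^2. The source does not state FixedPointFinder's exact objective, so treat q_value as an illustration of the idea rather than the tool's implementation.

```python
import numpy as np

# Hypothetical fixed weights standing in for a trained model (2 inputs, 2 hidden units).
W_IH = np.eye(2)
W_HH = 1.5 * np.eye(2)

def toy_rnn(inputs, hidden):
    """Follows the n_time == 1 convention of FlipFlopRNN.__call__.

    inputs: [batch, 1, n_inputs], hidden: [batch, n_hidden].
    Returns (outputs [batch, 1, n_hidden], h_next [batch, n_hidden]).
    """
    x_t = inputs[:, 0, :]
    h_next = np.tanh(x_t @ W_IH + hidden @ W_HH)
    return h_next[:, None, :], h_next

def q_value(model, x, u):
    """Assumed objective: q = 0.5 * ||F(x, u) - x||^2 for each candidate state."""
    # Repeat the constant input u for every candidate, with a time axis of length 1.
    inputs = np.broadcast_to(u, (x.shape[0], 1, u.shape[0]))
    _, h_next = model(inputs, x)
    return 0.5 * np.sum((h_next - x) ** 2, axis=1)

candidates = np.array([[0.0, 0.0], [0.5, -0.5]])
q = q_value(toy_rnn, candidates, np.zeros(2))
print(q)  # first entry is exactly 0 because tanh(0) = 0; second is positive
```

A candidate with q = 0 is an exact fixed point under the given constant input; an optimizer would move the second candidate until its q also approaches zero.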

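3.3 Component 3: the train_flipflop_rnn function

This is the train_flipflop_rnn function from the flipflop_fixed_points.py script.

Main updates:

  • Uses the modern BrainPy optimizer API: bp.optimizers.Adam(lr=..., train_vars=...)

  • Handles parameter-name mapping (vars() returns full names such as 'FlipFlopRNN0.w_ih')

  • No longer uses braintools or the deprecated register_trainable_weights()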

[4]:
def train_flipflop_rnn(rnn, train_data, valid_data,
                       learning_rate=0.08,
                       batch_size=128,
                       max_epochs=1000,
                       min_loss=1e-4,
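                       print_every=10):
    """Train the RNN with Adam on a mean-squared-error loss.

    Args:
        rnn: FlipFlopRNN instance; its variables are updated by the optimizer.
        train_data: dict with 'inputs' and 'targets' arrays for training.
        valid_data: dict with 'inputs' and 'targets' arrays for validation.
        learning_rate: Fixed Adam learning rate.
        batch_size: Number of trials per gradient step.
        max_epochs: Maximum number of passes over the training set.
        min_loss: Stop early once the epoch training loss drops below this.
        print_every: Print train/validation loss every this many epochs.

    Returns:
        List of per-epoch mean training losses.
    """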
    print("\n" + "=" * 70)
    print("Training FlipFlop RNN (Using brainpy optimizer)")
    print("=" * 70)

    # Prepare data
    train_inputs = jnp.array(train_data['inputs'])
    train_targets = jnp.array(train_data['targets'])
    valid_inputs = jnp.array(valid_data['inputs'])
    valid_targets = jnp.array(valid_data['targets'])
    n_train = train_inputs.shape[0]
    n_batches = n_train // batch_size

    # Get trainable variables from the model
    # Note: vars() returns keys like 'FlipFlopRNN0.w_ih', we need just 'w_ih' for computation
    train_vars = {name: var for name, var in rnn.vars().items() if isinstance(var, bm.Variable)}
    # Create mapping between short names and full names
    name_mapping = {name.split('.')[-1]: name for name in train_vars.keys()}
    # Extract just the parameter name (after the last dot) for gradient computation
    params = {name.split('.')[-1]: var.value for name, var in train_vars.items()}

    # Initialize optimizer with train_vars parameter (modern brainpy API)
    optimizer = bp.optimizers.Adam(lr=learning_rate, train_vars=train_vars)

    # Define JIT-compiled gradient step
    @jax.jit
    def grad_step(params, batch_inputs, batch_targets):
        """Pure function to compute loss and gradients"""
        def forward_pass(p, inputs):
            batch_size = inputs.shape[0]
            h = jnp.tile(p['h0'], (batch_size, 1))

            def scan_fn(carry, x_t):
                h_prev = carry
                if rnn.rnn_type == "tanh":
                    h_next = jnp.tanh(x_t @ p['w_ih'] + h_prev @ p['w_hh'] + p['b_h'])
                elif rnn.rnn_type == "gru":
                    r = jax.nn.sigmoid(x_t @ p['w_ir'] + h_prev @ p['w_hr'] + p['b_r'])
                    z = jax.nn.sigmoid(x_t @ p['w_iz'] + h_prev @ p['w_hz'] + p['b_z'])
                    n = jnp.tanh(x_t @ p['w_in'] + (r * h_prev) @ p['w_hn'] + p['b_n'])
                    h_next = (1 - z) * n + z * h_prev
                else:
                    h_next = h_prev
                y_t = h_next @ p['w_out'] + p['b_out']
                return h_next, y_t

            inputs_transposed = inputs.transpose(1, 0, 2)
            _, outputs_seq = jax.lax.scan(scan_fn, h, inputs_transposed)
            outputs = outputs_seq.transpose(1, 0, 2)
            return outputs

        def loss_fn(p):
            outputs = forward_pass(p, batch_inputs)
            return jnp.mean((outputs - batch_targets) ** 2)

        loss_val, grads = jax.value_and_grad(loss_fn)(params)
        return loss_val, grads

    losses = []
    print("\nTraining parameters:")
    print(f"  Batch size: {batch_size}")
    print(f"  Learning rate:{learning_rate:.6f} (Fixed)")

    for epoch in range(max_epochs):
        perm = np.random.permutation(n_train)
        epoch_loss = 0.0
        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            batch_inputs = train_inputs[perm[start_idx:end_idx]]
            batch_targets = train_targets[perm[start_idx:end_idx]]
            loss_val, grads_short = grad_step(params, batch_inputs, batch_targets)
            # Map gradients back to full names for optimizer
            grads = {name_mapping[short_name]: grad for short_name, grad in grads_short.items()}
            optimizer.update(grads)
            # Update params with current variable values (extract parameter names)
            params = {name.split('.')[-1]: var.value for name, var in train_vars.items()}
            epoch_loss += float(loss_val)
        epoch_loss /= n_batches
        losses.append(epoch_loss)

        if epoch % print_every == 0:
            valid_outputs, _ = rnn(valid_inputs)
            valid_loss = float(jnp.mean((valid_outputs - valid_targets) ** 2))
            print(f"Epoch {epoch:4d}: train_loss = {epoch_loss:.6f}, "
                  f"valid_loss = {valid_loss:.6f}")
        if epoch_loss < min_loss:
            print(f"\nReached target loss {min_loss:.2e} at epoch {epoch}")
            break

    # Training complete
    valid_outputs, _ = rnn(valid_inputs)
    final_valid_loss = float(jnp.mean((valid_outputs - valid_targets) ** 2))
    print("\n" + "=" * 70)
    print("Training Complete!")
    print("=" * 70)
    print(f"Final training loss: {epoch_loss:.6f}")
    print(f"Final validation loss: {final_valid_loss:.6f}")
    print(f"Total epochs: {epoch + 1}")
    return losses
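
The core pattern of the training loop is a pure, JIT-compiled function that maps a dictionary of parameters to a loss and a dictionary of gradients with matching keys. The sketch below shows that pattern on a tiny linear regression. The problem, the parameter names w and b, and the use of plain gradient descent in place of Adam are all assumptions made for illustration, so it is not a drop-in replacement for train_flipflop_rnn.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    """Mean squared error of a linear model y_hat = x @ w + b."""
    y_hat = x @ params["w"] + params["b"]
    return jnp.mean((y_hat - y) ** 2)

@jax.jit
def grad_step(params, x, y):
    """Return the loss and a gradient dict with the same keys as params."""
    return jax.value_and_grad(loss_fn)(params, x, y)

# Tiny regression problem: y = 2 * x + 1.
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 2.0 * x + 1.0
params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}

learning_rate = 0.1
first_loss, _ = grad_step(params, x, y)
for _ in range(500):
    loss, grads = grad_step(params, x, y)
    # Plain gradient descent stands in for the Adam optimizer here.
    params = {k: params[k] - learning_rate * grads[k] for k in params}
final_loss, _ = grad_step(params, x, y)

print(float(first_loss), float(final_loss))
```

Because the gradient dictionary mirrors the parameter dictionary key for key, remapping short names such as w_ih to full variable names is a simple dictionary comprehension, which is what name_mapping does in the function above.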

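4. Core Usage: Training and Finding Fixed Points

We reproduce the logic of the main function and the if __name__ == "__main__": block of the flipflop_fixed_points.py script.

We will:

  1. Define the task configurations.

  2. Set parameters and generate data.

  3. Train the model, or load it from a checkpoint if one exists.

  4. Initialize and run FixedPointFinder.

  5. Print and visualize the results.

4.1 Step 1: Define the configuration and parameters

This code comes from the global TASK_CONFIGS dictionary, the if __name__ == "__main__": block, and the beginning of the main function in flipflop_fixed_points.py.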

[5]:
# Configuration Dictionary
TASK_CONFIGS = {
    "2_bit": {
        "n_bits": 2,
        "n_hidden": 3,
        "n_trials_train": 512,
        "n_inits":1024,
    },
    "3_bit": {
        "n_bits": 3,
        "n_hidden": 4,
        "n_trials_train": 512,
        "n_inits":1024,
    },
    "4_bit": {
        "n_bits": 4,
        "n_hidden": 6,
        "n_trials_train": 512,
        "n_inits":1024,
    },
}

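# --- Set parameters ---
# (This logic comes from the original script's if __name__ == "__main__" block)
config_to_run = "3_bit"  # Configuration to run
seed_to_use = 42       # Use a fixed seed

config_name = config_to_run
seed = seed_to_use

# (This logic comes from the original script's main function)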
if config_name not in TASK_CONFIGS:
    raise ValueError(f"Unknown config_name: {config_name}. Available: {list(TASK_CONFIGS.keys())}")
config = TASK_CONFIGS[config_name]

# Set random seeds
np.random.seed(seed)
random.seed(seed)

print(f"\n--- Running FlipFlop Task ({config_name}) ---")
print(f"Seed: {seed}")

n_bits = config["n_bits"]
n_hidden = config["n_hidden"]
n_trials_train = config["n_trials_train"]
n_inits = config["n_inits"]

n_time = 64
n_trials_valid = 128
n_trials_test = 128
rnn_type = "tanh"
learning_rate = 0.08
batch_size = 128
max_epochs = 500  # (The original uses 1000; 500 runs faster in a notebook)
min_loss = 1e-4

--- Running FlipFlop Task (3_bit) ---
Seed: 42

4.2 Step 2: Generate data and train the model

This code comes from the main function of flipflop_fixed_points.py.

[6]:
# Generate data
data_gen = FlipFlopData(n_bits=n_bits, n_time=n_time, p=0.5, random_seed=seed)
train_data = data_gen.generate_data(n_trials_train)
valid_data = data_gen.generate_data(n_trials_valid)
test_data = data_gen.generate_data(n_trials_test)

# Create RNN model
rnn = FlipFlopRNN(n_inputs=n_bits, n_hidden=n_hidden, n_outputs=n_bits, rnn_type=rnn_type, seed=seed)

# Check for checkpoint
checkpoint_path = f"flipflop_rnn_{config_name}_checkpoint.msgpack"
if load_checkpoint(rnn, checkpoint_path):
    print(f"Loaded model from checkpoint: {checkpoint_path}")
else:
    # Train the RNN
    print(f"No checkpoint found ({checkpoint_path}). Training...")
    losses = train_flipflop_rnn(
        rnn,
        train_data,
        valid_data,
        learning_rate=learning_rate,
        batch_size=batch_size,
        max_epochs=max_epochs,
        min_loss=min_loss,
        print_every=10
    )
No checkpoint found (flipflop_rnn_3_bit_checkpoint.msgpack). Training...

======================================================================
Training FlipFlop RNN (Using bts Scheduler & built-in Grad Clip)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate:0.080000 (Fixed)
Epoch    0: train_loss = 0.946821, valid_loss = 0.810312, lr = 0.080000
Epoch   10: train_loss = 0.037047, valid_loss = 0.025001, lr = 0.080000
Epoch   20: train_loss = 0.002670, valid_loss = 0.002536, lr = 0.080000
Epoch   30: train_loss = 0.001543, valid_loss = 0.001499, lr = 0.080000
Epoch   40: train_loss = 0.001115, valid_loss = 0.001087, lr = 0.080000
Epoch   50: train_loss = 0.000843, valid_loss = 0.000825, lr = 0.080000
Epoch   60: train_loss = 0.000657, valid_loss = 0.000642, lr = 0.080000
Epoch   70: train_loss = 0.000521, valid_loss = 0.000508, lr = 0.080000
Epoch   80: train_loss = 0.000423, valid_loss = 0.000413, lr = 0.080000
Epoch   90: train_loss = 0.000350, valid_loss = 0.000340, lr = 0.080000
Epoch  100: train_loss = 0.000293, valid_loss = 0.000286, lr = 0.080000
Epoch  110: train_loss = 0.000248, valid_loss = 0.000242, lr = 0.080000
Epoch  120: train_loss = 0.000214, valid_loss = 0.000209, lr = 0.080000
Epoch  130: train_loss = 0.000187, valid_loss = 0.000183, lr = 0.080000
Epoch  140: train_loss = 0.000166, valid_loss = 0.000161, lr = 0.080000
Epoch  150: train_loss = 0.000148, valid_loss = 0.000144, lr = 0.080000
Epoch  160: train_loss = 0.000133, valid_loss = 0.000129, lr = 0.080000
Epoch  170: train_loss = 0.000120, valid_loss = 0.000117, lr = 0.080000
Epoch  180: train_loss = 0.000109, valid_loss = 0.000108, lr = 0.080000
Epoch  190: train_loss = 0.000100, valid_loss = 0.000099, lr = 0.080000

Reached target loss 1.00e-04 at epoch 191

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000099
Final validation loss: 0.000098
Total epochs: 192

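4.3 Step 3: Run the fixed-point analysis

This section shows the concrete usage of FixedPointFinder, taken from the main function.

Usage notes:

  1. Collect the state trajectory hiddens_np. FixedPointFinder samples its initial points from these "real" states.

  2. Initialize FixedPointFinder:

    • rnn_model: pass the rnn instance.

    • do_compute_jacobians=True: must be True. This computes the Jacobian J = dF/dx at each fixed point.

    • do_decompose_jacobians=True: must be True. This computes the eigenvalues of J, which determine stability: a fixed point of this discrete-time map is stable when every eigenvalue has magnitude below 1.

  3. Run find_fixed_points:

    • state_traj: pass hiddens_np.

    • inputs: we are looking for "memory" states, that is, fixed points under zero input, so we pass a constant zero vector, constant_input.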
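
As a rough illustration of what the two Jacobian flags compute, the sketch below evaluates J = dF/dh with jax.jacobian for a hypothetical 2-unit tanh RNN and classifies two of its fixed points by the largest eigenvalue magnitude. The weights W_HH and the helper max_abs_eig are assumptions for this example; FixedPointFinder's internals are not shown in the source and may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical recurrent weights of a 2-unit tanh RNN (not taken from the trained model).
W_HH = jnp.array([[2.0, 0.0], [0.0, 0.5]])

def F(h):
    """One RNN step with zero input for a single state vector h of shape [n_hidden]."""
    return jnp.tanh(h @ W_HH)

def max_abs_eig(h_star):
    """Largest eigenvalue magnitude of J = dF/dh evaluated at h_star."""
    J = jax.jacobian(F)(h_star)
    return float(np.max(np.abs(np.linalg.eigvals(np.asarray(J)))))

# h = 0 is a fixed point because tanh(0) = 0; there J equals W_HH^T,
# whose eigenvalues are 2.0 and 0.5.
rho_origin = max_abs_eig(jnp.zeros(2))

# Iterating F from a small offset along the first axis settles onto another fixed point.
h = jnp.array([0.1, 0.0])
for _ in range(200):
    h = F(h)
rho_attractor = max_abs_eig(h)

is_stable_origin = rho_origin < 1.0        # False: one eigenvalue has magnitude 2
is_stable_attractor = rho_attractor < 1.0  # True: both magnitudes are below 1

print(rho_origin, rho_attractor)
```

The origin here is a saddle (one expanding and one contracting direction), which is the typical character of the unstable fixed points that sit between memory states.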

[7]:
# Fixed Point Analysis
print("\n--- Fixed Point Analysis ---")
inputs_jax = jnp.array(test_data["inputs"])
outputs, hiddens = rnn(inputs_jax)
hiddens_np = np.array(hiddens)

# Find fixed points
finder = FixedPointFinder(
    rnn,
    method="joint",
    max_iters=5000,
    lr_init=0.02,
    tol_q=1e-4,
    final_q_threshold=1e-6,
    tol_unique=1e-2,
    do_compute_jacobians=True,
    do_decompose_jacobians=True,
    outlier_distance_scale=10.0,
    verbose=True,
    super_verbose=True,
)

constant_input = np.zeros((1, n_bits), dtype=np.float32)

unique_fps, all_fps = finder.find_fixed_points(
    state_traj=hiddens_np,
    inputs=constant_input,
    n_inits=n_inits,
    noise_scale=0.4,
)

--- Fixed Point Analysis ---

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
/var/folders/x0/_jqxxbbn0rsdn6b4h6fxbrjr0000gn/T/ipykernel_42544/1298414900.py:25: UserWarning: Joint optimization with n_inits=1024 may be inefficient and use excessive memory. Consider using sequential optimization or reducing n_inits.
  unique_fps, all_fps = finder.find_fixed_points(
        Iter: 100, q = 1.84e-04 +/- 1.43e-03, dq = 1.67e-05 +/- 1.23e-04, lr = 2.00e-02, avg iter time = 1.18e-02 sec.
        Optimization complete to desired tolerance.
                184 iters, q = 1.24e-07 +/- 3.13e-06, dq = 1.11e-08 +/- 2.42e-07, lr = 2.00e-02, avg iter time = 8.81e-03 sec
        Identified 26 unique fixed points.
        Computing recurrent Jacobian at 26 unique fixed points.
        Computing input Jacobian at 26 unique fixed points.
Decomposing 26 Jacobians...
Found 9 stable and 17 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                Excluded 1 low-quality fixed points.
                25 high-quality fixed points remain.
        Fixed point finding complete.

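4.4 Result analysis and visualization

find_fixed_points returns two objects:

  • all_fps: every result found from the n_inits initial points.

  • unique_fps: the result we care about most, namely the set of distinct fixed points that remain after filtering with tol_unique.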
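
The step from all_fps to unique_fps can be pictured as tolerance-based de-duplication. The sketch below uses a simple greedy rule with a hypothetical helper, unique_within_tol; the source does not describe the exact procedure behind tol_unique, so this shows only the general idea.

```python
import numpy as np

def unique_within_tol(points, tol):
    """Greedy de-duplication: keep a point only if it is farther than tol from every kept point."""
    kept = []
    for p in points:
        if all(np.linalg.norm(p - k) > tol for k in kept):
            kept.append(p)
    return np.array(kept)

# Five optimization results that cluster around two distinct fixed points.
all_points = np.array([
    [1.000, 1.000],
    [1.001, 0.999],
    [-1.000, 1.000],
    [0.999, 1.002],
    [-1.002, 1.001],
])
unique_points = unique_within_tol(all_points, tol=1e-2)
print(unique_points)  # two rows remain: [1, 1] and [-1, 1]
```

With 1024 initial states and only a few dozen distinct fixed points, most optimization runs end up as near-duplicates, which is why the unique set is so much smaller than n_inits.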

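How to read the results:

  • unique_fps.n: the number of unique fixed points found.

  • unique_fps.qstar: the q value of each fixed point, which measures how far the point is from satisfying x = F(x). The closer to 0, the better.

  • unique_fps.is_stable: (key) whether each fixed point is stable.

For an N-bit task we expect to find 2^N stable fixed points, one for each of the 2^N memory states.

The code cell below combines the end of the main function with the last line of the if __name__ == "__main__": block of flipflop_fixed_points.py. It prints all analysis results and generates the plots.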

[8]:
# Print results
print("\n--- Fixed Point Analysis Results ---")
unique_fps.print_summary()

if unique_fps.n > 0:
    print(f"\nDetailed Fixed Point Information (Top 10):")
    print(f"{'#':<4} {'q-value':<12} {'Stability':<12} {'Max |eig|':<12}")
    print("-" * 45)
    for i in range(min(10, unique_fps.n)):
        stability_str = "Stable" if unique_fps.is_stable[i] else "Unstable"
        max_eig = np.abs(unique_fps.eigval_J_xstar[i, 0])
        print(
            f"{i + 1:<4} {unique_fps.qstar[i]:<12.2e} {stability_str:<12} {max_eig:<12.4f}"
        )

    # Visualize fixed points - 2D
    config_2d = PlotConfig(
        title=f"FlipFlop Fixed Points ({config_name} - 2D PCA)",
        xlabel="PC 1", ylabel="PC 2", figsize=(10, 8),
        show=True
    )
    plot_fixed_points_2d(unique_fps, hiddens_np, config=config_2d)

    # Visualize fixed points - 3D
    config_3d = PlotConfig(
        title=f"FlipFlop Fixed Points ({config_name} - 3D PCA)",
        figsize=(12, 10),
        show=True
    )
    plot_fixed_points_3d(
        unique_fps, hiddens_np, config=config_3d,
        plot_batch_idx=list(range(30)), plot_start_time=10
    )

print("\n--- Analysis complete ---")

--- Fixed Point Analysis Results ---

=== Fixed Points Summary ===
Number of fixed points: 25
State dimension: 4
Input dimension: 3

q values: min=1.84e-12, median=4.63e-11, max=1.46e-08
Iterations: min=184, median=184, max=184

Stable fixed points: 8 / 25

Detailed Fixed Point Information (Top 10):
#    q-value      Stability    Max |eig|
---------------------------------------------
1    7.47e-12     Stable       0.1546
2    4.66e-12     Stable       0.2300
3    2.87e-12     Stable       0.2611
4    7.70e-11     Unstable     1.9137
5    8.48e-12     Unstable     2.1957
6    8.06e-12     Stable       0.1676
7    8.17e-12     Stable       0.1617
8    1.84e-12     Stable       0.2382
9    9.57e-12     Stable       0.1491
10   2.75e-12     Stable       0.2520
[Figure: FlipFlop Fixed Points (3_bit - 2D PCA)]
  PCA explained variance: [0.48812193 0.26487482 0.24671465]
  Total variance explained: 99.97%
[Figure: FlipFlop Fixed Points (3_bit - 3D PCA)]

--- Analysis complete ---

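5. Multi-Configuration Comparison: 2-bit, 3-bit, 4-bit

Below we run all three configurations to show the fixed-point analysis results for tasks of increasing complexity.

Expected results:

  • 2-bit: 4 stable fixed points (2² = 4 memory states)

  • 3-bit: 8 stable fixed points (2³ = 8 memory states)

  • 4-bit: 16 stable fixed points (2⁴ = 16 memory states)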
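
The expected counts follow from enumerating every combination of +1 and -1 across the channels. A minimal sketch, with memory_states as a hypothetical helper name:

```python
from itertools import product

def memory_states(n_bits):
    """All combinations of +1/-1 across n_bits channels."""
    return list(product([+1, -1], repeat=n_bits))

expected_counts = {n: len(memory_states(n)) for n in (2, 3, 4)}
two_bit_states = memory_states(2)

print(expected_counts)  # {2: 4, 3: 8, 4: 16}
print(two_bit_states)   # [(1, 1), (1, -1), (-1, 1), (-1, -1)]
```

Each tuple is one target output pattern the network must be able to hold indefinitely, so each needs its own stable fixed point.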

[11]:
import matplotlib.pyplot as plt

def run_flipflop_analysis(config_name, seed=42):
    """Run the full analysis pipeline for a single configuration."""
    config = TASK_CONFIGS[config_name]
    n_bits = config["n_bits"]
    n_hidden = config["n_hidden"]
    n_trials_train = config["n_trials_train"]
    n_inits = config["n_inits"]

    # Set random seeds
    np.random.seed(seed)
    random.seed(seed)

    print(f"\n{'='*60}")
    print(f"Config: {config_name} ({n_bits} bits, {n_hidden} hidden units)")
    print(f"{'='*60}")

    # Generate data
    data_gen = FlipFlopData(n_bits=n_bits, n_time=64, p=0.5, random_seed=seed)
    train_data = data_gen.generate_data(n_trials_train)
    valid_data = data_gen.generate_data(128)
    test_data = data_gen.generate_data(128)

    # Create and train the model
    rnn = FlipFlopRNN(n_inputs=n_bits, n_hidden=n_hidden,
                      n_outputs=n_bits, rnn_type="tanh", seed=seed)

    checkpoint_path = f"flipflop_rnn_{config_name}_checkpoint.msgpack"
    if not load_checkpoint(rnn, checkpoint_path):
        print("Training model...")
        train_flipflop_rnn(rnn, train_data, valid_data,
                          learning_rate=0.08, batch_size=128,
                          max_epochs=500, min_loss=1e-4, print_every=50)
    else:
        print(f"Loaded model from checkpoint: {checkpoint_path}")

    # Collect the hidden-state trajectory
    inputs_jax = jnp.array(test_data["inputs"])
    outputs, hiddens = rnn(inputs_jax)
    hiddens_np = np.array(hiddens)

    # Fixed-point analysis
    finder = FixedPointFinder(
        rnn, method="joint", max_iters=5000, lr_init=0.02,
        tol_q=1e-4, final_q_threshold=1e-6, tol_unique=1e-2,
        do_compute_jacobians=True, do_decompose_jacobians=True,
        outlier_distance_scale=10.0, verbose=True, super_verbose=False,
    )

    constant_input = np.zeros((1, n_bits), dtype=np.float32)
    unique_fps, _ = finder.find_fixed_points(
        state_traj=hiddens_np, inputs=constant_input,
        n_inits=n_inits, noise_scale=0.4,
    )

    return unique_fps, hiddens_np, config_name

# Store the results for every configuration
all_results = {}
for cfg in ["2_bit", "3_bit", "4_bit"]:
    unique_fps, hiddens_np, name = run_flipflop_analysis(cfg, seed=43)
    all_results[cfg] = {"fps": unique_fps, "hiddens": hiddens_np}

    # Print a summary
    n_stable = np.sum(unique_fps.is_stable) if unique_fps.n > 0 else 0
    expected = 2 ** TASK_CONFIGS[cfg]["n_bits"]
    print(f"\nResult: found {unique_fps.n} fixed points, {n_stable} stable (expected: {expected})")

============================================================
Config: 2_bit (2 bits, 3 hidden units)
============================================================
Training model...

======================================================================
Training FlipFlop RNN (Using bts Scheduler & built-in Grad Clip)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate:0.080000 (Fixed)
Epoch    0: train_loss = 0.949601, valid_loss = 0.805686, lr = 0.080000
Epoch   50: train_loss = 0.001212, valid_loss = 0.001192, lr = 0.080000
Epoch  100: train_loss = 0.000509, valid_loss = 0.000503, lr = 0.080000
Epoch  150: train_loss = 0.000290, valid_loss = 0.000287, lr = 0.080000
Epoch  200: train_loss = 0.000190, valid_loss = 0.000188, lr = 0.080000
Epoch  250: train_loss = 0.000134, valid_loss = 0.000133, lr = 0.080000
Epoch  300: train_loss = 0.000100, valid_loss = 0.000099, lr = 0.080000

Reached target loss 1.00e-04 at epoch 300

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000100
Final validation loss: 0.000099
Total epochs: 301

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
/var/folders/x0/_jqxxbbn0rsdn6b4h6fxbrjr0000gn/T/ipykernel_42544/65684184.py:52: UserWarning: Joint optimization with n_inits=1024 may be inefficient and use excessive memory. Consider using sequential optimization or reducing n_inits.
  unique_fps, _ = finder.find_fixed_points(
        Optimization complete to desired tolerance.
                178 iters, q = 1.39e-07 +/- 2.73e-06, dq = 2.51e-08 +/- 6.28e-07, lr = 2.00e-02, avg iter time = 6.00e-03 sec
        Identified 10 unique fixed points.
        Computing recurrent Jacobian at 10 unique fixed points.
        Computing input Jacobian at 10 unique fixed points.
Decomposing 10 Jacobians...
Found 5 stable and 5 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                Excluded 1 low-quality fixed points.
                9 high-quality fixed points remain.
        Fixed point finding complete.


Result: found 9 fixed points, 4 stable (expected: 4)

============================================================
Config: 3_bit (3 bits, 4 hidden units)
============================================================
Training model...

======================================================================
Training FlipFlop RNN (Using bts Scheduler & built-in Grad Clip)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate:0.080000 (Fixed)
Epoch    0: train_loss = 0.934688, valid_loss = 0.771106, lr = 0.080000
Epoch   50: train_loss = 0.000601, valid_loss = 0.000595, lr = 0.080000
Epoch  100: train_loss = 0.000243, valid_loss = 0.000240, lr = 0.080000
Epoch  150: train_loss = 0.000138, valid_loss = 0.000137, lr = 0.080000

Reached target loss 1.00e-04 at epoch 188

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000100
Final validation loss: 0.000099
Total epochs: 189

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
        Optimization complete to desired tolerance.
                209 iters, q = 2.18e-07 +/- 4.03e-06, dq = 2.30e-08 +/- 4.08e-07, lr = 2.00e-02, avg iter time = 4.65e-03 sec
        Identified 27 unique fixed points.
        Computing recurrent Jacobian at 27 unique fixed points.
        Computing input Jacobian at 27 unique fixed points.
Decomposing 27 Jacobians...
Found 8 stable and 19 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                Excluded 2 low-quality fixed points.
                25 high-quality fixed points remain.
        Fixed point finding complete.


Result: found 25 fixed points, 8 stable (expected: 8)

============================================================
Config: 4_bit (4 bits, 6 hidden units)
============================================================
Training model...

======================================================================
Training FlipFlop RNN (Using bts Scheduler & built-in Grad Clip)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate:0.080000 (Fixed)
Epoch    0: train_loss = 0.930025, valid_loss = 0.750363, lr = 0.080000
Epoch   50: train_loss = 0.000319, valid_loss = 0.000313, lr = 0.080000
Epoch  100: train_loss = 0.000114, valid_loss = 0.000112, lr = 0.080000

Reached target loss 1.00e-04 at epoch 109

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000099
Final validation loss: 0.000097
Total epochs: 110

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
        Optimization complete to desired tolerance.
                383 iters, q = 9.75e-08 +/- 3.03e-06, dq = 1.01e-08 +/- 3.16e-07, lr = 2.00e-02, avg iter time = 4.41e-03 sec
        Identified 67 unique fixed points.
        Computing recurrent Jacobian at 67 unique fixed points.
        Computing input Jacobian at 67 unique fixed points.
Decomposing 67 Jacobians...
Found 16 stable and 51 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                Excluded 1 low-quality fixed points.
                66 high-quality fixed points remain.
        Fixed point finding complete.


Result: found 66 fixed points, 16 stable (expected: 16)

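5.1 2D visualization comparison

The 2D PCA projections of the three configurations show directly that the number of fixed points grows with task complexity.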

[12]:
# 2D visualization: show each configuration separately
for cfg in ["2_bit", "3_bit", "4_bit"]:
    result = all_results[cfg]
    unique_fps = result["fps"]
    hiddens_np = result["hiddens"]

    n_bits = int(cfg[0])
    n_stable = np.sum(unique_fps.is_stable) if unique_fps.n > 0 else 0

    config_2d = PlotConfig(
        title=f"FlipFlop {cfg}: {n_stable} stable fixed points (2D PCA)",
        xlabel="PC 1", ylabel="PC 2",
        figsize=(8, 6),
        show=True
    )

    plot_fixed_points_2d(unique_fps, hiddens_np, config=config_2d)
[Figure: FlipFlop 2_bit: 4 stable fixed points (2D PCA)]
[Figure: FlipFlop 3_bit: 8 stable fixed points (2D PCA)]
[Figure: FlipFlop 4_bit: 16 stable fixed points (2D PCA)]

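5.2 3D visualization comparison

The 3D PCA projections show how the hidden-state trajectories and the fixed points are distributed in three dimensions, which reveals the RNN's dynamical structure more clearly.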

[13]:
# 3D visualization: show each configuration separately
for cfg in ["2_bit", "3_bit", "4_bit"]:
    result = all_results[cfg]
    unique_fps = result["fps"]
    hiddens_np = result["hiddens"]

    n_bits = int(cfg[0])
    n_stable = np.sum(unique_fps.is_stable) if unique_fps.n > 0 else 0

    config_3d = PlotConfig(
        title=f"FlipFlop {cfg}: {n_stable} stable fixed points (3D PCA)",
        figsize=(10, 8),
        show=True
    )

    plot_fixed_points_3d(
        unique_fps, hiddens_np, config=config_3d,
        plot_batch_idx=list(range(20)), plot_start_time=10
    )
  PCA explained variance: [6.6616559e-01 3.3381739e-01 1.7039767e-05]
  Total variance explained: 100.00%
[Figure: FlipFlop 2_bit: 4 stable fixed points (3D PCA)]
  PCA explained variance: [0.4709875  0.266616   0.26175115]
  Total variance explained: 99.94%
[Figure: FlipFlop 3_bit: 8 stable fixed points (3D PCA)]
  PCA explained variance: [0.37155104 0.24347201 0.19847597]
  Total variance explained: 81.35%
[Figure: FlipFlop 4_bit: 16 stable fixed points (3D PCA)]
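
The "PCA explained variance" lines in the output report how much of the hidden-state variance each plotted axis captures. The sketch below shows one common way to compute such a projection with an SVD. The helper pca_project and the synthetic data are assumptions for illustration, and the plotting functions in canns may compute their projection differently.

```python
import numpy as np

def pca_project(states, n_components):
    """Project [n_samples, n_features] states onto their top principal components."""
    centered = states - states.mean(axis=0)
    # Rows of vt are principal directions; squared singular values are proportional to variances.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    explained_ratio = (s ** 2) / np.sum(s ** 2)
    projected = centered @ vt[:n_components].T
    return projected, explained_ratio[:n_components]

# Synthetic hidden states shaped like hiddens_np: [trials, time, n_hidden].
rng = np.random.default_rng(0)
hiddens = rng.normal(size=(8, 16, 4)) * np.array([3.0, 2.0, 1.0, 0.01])
flat = hiddens.reshape(-1, hiddens.shape[-1])   # merge the trial and time axes

projected, ratio = pca_project(flat, n_components=3)
print(projected.shape, ratio, ratio.sum())
```

When the three ratios sum to nearly 100%, as in the 2-bit and 3-bit runs, the 3D plot loses almost nothing; the 4-bit run has 6 hidden units and its first three components explain about 81%, so some structure lies outside the plotted axes.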

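6. Summary

This tutorial showed how to use FixedPointFinder to analyze the dynamical structure of an RNN:

  1. FlipFlop task: the RNN must remember the state of several binary channels.

  2. Fixed-point analysis: find the stable states the RNN uses as "memory".

  3. Visualization: use PCA to project the fixed points in hidden-state space down to two or three dimensions.

Key findings

  • For an N-bit task, the RNN learns to create 2^N stable fixed points (4, 8, and 16 in the 2-bit, 3-bit, and 4-bit runs above).

  • These fixed points correspond to the different combinations of memory states.

  • Fixed-point analysis is a powerful tool for understanding an RNN's internal computation.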