Source code for src.canns.analyzer.slow_points.visualization

"""Visualization functions for fixed point analysis."""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.decomposition import PCA

from ..plotting.config import PlotConfig
from .fixed_points import FixedPoints

__all__ = ["plot_fixed_points_2d", "plot_fixed_points_3d"]


[docs] def plot_fixed_points_2d( fixed_points: FixedPoints, state_traj: np.ndarray, config: PlotConfig | None = None, plot_batch_idx: list[int] | None = None, plot_start_time: int = 0, ) -> Figure: """Plot fixed points and trajectories in 2D using PCA. Args: fixed_points: FixedPoints object containing analysis results. state_traj: State trajectories [n_batch x n_time x n_states]. config: Plot configuration. If None, uses default config. plot_batch_idx: Batch indices to plot trajectories. If None, plots first 30. plot_start_time: Starting time index for trajectory plotting. Returns: matplotlib Figure object. Example: >>> from canns.analyzer.slow_points import plot_fixed_points_2d, FixedPoints >>> from canns.analyzer.plotting import PlotConfig >>> config = PlotConfig( ... title="Fixed Points Analysis", ... figsize=(10, 8), ... save_path="fps_2d.png" ... ) >>> fig = plot_fixed_points_2d(unique_fps, hiddens, config=config) """ # Default config if config is None: config = PlotConfig( title="Fixed Points (2D PCA)", xlabel="PC 1", ylabel="PC 2", figsize=(10, 8), show_legend=True, ) n_batch, n_time, n_states = state_traj.shape # Default: plot first 30 trajectories if plot_batch_idx is None: plot_batch_idx = list(range(min(30, n_batch))) # Flatten trajectories for PCA all_states = [] for batch_idx in plot_batch_idx: all_states.append(state_traj[batch_idx, plot_start_time:, :]) all_states = np.concatenate(all_states, axis=0) # [n_samples x n_states] # Add fixed points if fixed_points.n > 0: all_states = np.concatenate([all_states, fixed_points.xstar], axis=0) # Perform PCA pca = PCA(n_components=2) all_pca = pca.fit_transform(all_states) # Split back n_traj_points = sum(len(state_traj[i, plot_start_time:, :]) for i in plot_batch_idx) traj_pca = all_pca[:n_traj_points] fps_pca = all_pca[n_traj_points:] if fixed_points.n > 0 else np.array([]) # Create figure fig, ax = plt.subplots(figsize=config.figsize) # Plot trajectories start_idx = 0 for batch_idx in plot_batch_idx: traj_len = len(state_traj[batch_idx, plot_start_time:, :]) end_idx = start_idx + traj_len ax.scatter( traj_pca[start_idx:end_idx, 0], traj_pca[start_idx:end_idx, 1], c="lightblue", s=8.0, alpha=1.0, label="Trajectories" if batch_idx == plot_batch_idx[0] else "", zorder=-1, ) start_idx = end_idx # Plot fixed points if fixed_points.n > 0: # Separate stable and unstable stable_mask = fixed_points.is_stable unstable_mask = ~stable_mask if np.any(unstable_mask): ax.scatter( fps_pca[unstable_mask, 0], fps_pca[unstable_mask, 1], c="red", marker="x", s=200, linewidths=3, label="Unstable Fixed Points", zorder=10, ) if np.any(stable_mask): ax.scatter( fps_pca[stable_mask, 0], fps_pca[stable_mask, 1], c="darkred", marker="o", s=200, edgecolors="black", linewidths=2, label="Stable Fixed Points", zorder=10, ) ax.set_xlabel(config.xlabel) ax.set_ylabel(config.ylabel) ax.set_title(config.title) if config.show_legend: ax.legend() if config.grid: ax.grid(True, alpha=0.3) plt.tight_layout() # Save if path specified if config.save_path: plt.savefig(config.save_path, dpi=300, bbox_inches="tight") print(f" Saved 2D plot to: {config.save_path}") if config.show: plt.show() return fig
[docs] def plot_fixed_points_3d( fixed_points: FixedPoints, state_traj: np.ndarray, config: PlotConfig | None = None, plot_batch_idx: list[int] | None = None, plot_start_time: int = 0, ) -> Figure: """Plot fixed points and trajectories in 3D using PCA. Args: fixed_points: FixedPoints object containing analysis results. state_traj: State trajectories [n_batch x n_time x n_states]. config: Plot configuration. If None, uses default config. plot_batch_idx: Batch indices to plot trajectories. If None, plots first 30. plot_start_time: Starting time index for trajectory plotting. Returns: matplotlib Figure object. Example: >>> from canns.analyzer.slow_points import plot_fixed_points_3d, FixedPoints >>> from canns.analyzer.plotting import PlotConfig >>> config = PlotConfig( ... title="Fixed Points 3D", ... figsize=(12, 10), ... save_path="fps_3d.png" ... ) >>> fig = plot_fixed_points_3d(unique_fps, hiddens, config=config) """ # Default config if config is None: config = PlotConfig( title="Fixed Points (3D PCA)", xlabel="PC 1", ylabel="PC 2", figsize=(12, 10), show_legend=True, ) n_batch, n_time, n_states = state_traj.shape # Default: plot first 30 trajectories if plot_batch_idx is None: plot_batch_idx = list(range(min(30, n_batch))) # Flatten trajectories for PCA all_states = [] for batch_idx in plot_batch_idx: all_states.append(state_traj[batch_idx, plot_start_time:, :]) all_states = np.concatenate(all_states, axis=0) # [n_samples x n_states] # Add fixed points if fixed_points.n > 0: all_states = np.concatenate([all_states, fixed_points.xstar], axis=0) # Perform PCA pca = PCA(n_components=3) all_pca = pca.fit_transform(all_states) # Split back n_traj_points = sum(len(state_traj[i, plot_start_time:, :]) for i in plot_batch_idx) fps_pca = all_pca[n_traj_points:] if fixed_points.n > 0 else np.array([]) # Compute explained variance explained_var = pca.explained_variance_ratio_ total_var = np.sum(explained_var) * 100 # Create figure fig = plt.figure(figsize=config.figsize) ax = fig.add_subplot(111, projection="3d") # Plot trajectories as lines start_idx = 0 for i, batch_idx in enumerate(plot_batch_idx): traj_len = len(state_traj[batch_idx, plot_start_time:, :]) end_idx = start_idx + traj_len traj_segment = all_pca[start_idx:end_idx] ax.plot( traj_segment[:, 0], traj_segment[:, 1], traj_segment[:, 2], c="blue", alpha=0.3, linewidth=0.5, label="RNN Trajectories" if i == 0 else "", ) start_idx = end_idx # Plot fixed points if fixed_points.n > 0: # Separate stable and unstable stable_mask = fixed_points.is_stable unstable_mask = ~stable_mask if np.any(unstable_mask): ax.scatter( fps_pca[unstable_mask, 0], fps_pca[unstable_mask, 1], fps_pca[unstable_mask, 2], c="red", marker="x", s=200, linewidths=3, label="Unstable FPs", zorder=10, ) if np.any(stable_mask): ax.scatter( fps_pca[stable_mask, 0], fps_pca[stable_mask, 1], fps_pca[stable_mask, 2], c="darkred", marker="o", s=200, edgecolors="black", linewidths=2, label="Stable FPs", zorder=10, ) # Labels and title ax.set_xlabel(config.xlabel if config.xlabel else "PC 1") ax.set_ylabel(config.ylabel if config.ylabel else "PC 2") ax.set_zlabel("PC 3") # Add variance info to title title = config.title if title: title += f"\nVariance: {explained_var[0]:.1%}, {explained_var[1]:.1%}, {explained_var[2]:.1%} (total: {total_var:.1f}%)" ax.set_title(title) if config.show_legend: ax.legend() # Set aspect ratio to be equal max_range = ( np.array( [ fps_pca[:, 0].max() - fps_pca[:, 0].min() if fixed_points.n > 0 else 1, fps_pca[:, 1].max() - fps_pca[:, 1].min() if fixed_points.n > 0 else 1, fps_pca[:, 2].max() - fps_pca[:, 2].min() if fixed_points.n > 0 else 1, ] ).max() / 2.0 ) if fixed_points.n > 0: mid_x = (fps_pca[:, 0].max() + fps_pca[:, 0].min()) * 0.5 mid_y = (fps_pca[:, 1].max() + fps_pca[:, 1].min()) * 0.5 mid_z = (fps_pca[:, 2].max() + fps_pca[:, 2].min()) * 0.5 ax.set_xlim(mid_x - max_range, mid_x + max_range) ax.set_ylim(mid_y - max_range, mid_y + max_range) ax.set_zlim(mid_z - max_range, mid_z + max_range) # Print PCA info print(f" PCA explained variance: {explained_var}") print(f" Total variance explained: {total_var:.2f}%") plt.tight_layout() # Save if path specified if config.save_path: plt.savefig(config.save_path, dpi=300, bbox_inches="tight") print(f" Saved 3D plot to: {config.save_path}") if config.show: plt.show() return fig