Source code for src.canns.pipeline.theta_sweep

"""
Theta Sweep Pipeline for External Trajectory Analysis

This module provides a high-level pipeline for experimental scientists to analyze
their trajectory data using CANN theta sweep models without needing to understand
the underlying implementation details.
"""

from pathlib import Path
from typing import Any

import brainpy.math as bm
import numpy as np

from ..analyzer.plotting import PlotConfig
from ..analyzer.theta_sweep import (
    create_theta_sweep_grid_cell_animation,
    plot_population_activity_with_theta,
)
from ..models.basic.theta_sweep_model import (
    DirectionCellNetwork,
    GridCellNetwork,
    calculate_theta_modulation,
)
from ..task.open_loop_navigation import OpenLoopNavigationTask
from ._base import Pipeline


[docs] class ThetaSweepPipeline(Pipeline): """ High-level pipeline for theta sweep analysis of external trajectory data. This pipeline abstracts the complex workflow of running CANN theta sweep models on experimental trajectory data, making it accessible to researchers who want to analyze neural responses without diving into implementation details. Example: ```python # Simple usage - just provide trajectory data pipeline = ThetaSweepPipeline( trajectory_data=positions, # shape: (n_steps, 2) times=times # shape: (n_steps,) ) results = pipeline.run(output_dir="my_results/") print(f"Animation saved to: {results['animation_path']}") ``` """ def __init__( self, trajectory_data: np.ndarray, times: np.ndarray | None = None, env_size: float = 2.0, dt: float = 0.001, direction_cell_params: dict[str, Any] | None = None, grid_cell_params: dict[str, Any] | None = None, theta_params: dict[str, Any] | None = None, spatial_nav_params: dict[str, Any] | None = None, ): """ Initialize the theta sweep pipeline. Args: trajectory_data: Position coordinates with shape (n_steps, 2) for 2D trajectories times: Optional time array with shape (n_steps,). If None, uniform time steps will be used env_size: Environment size (assumes square environment) dt: Simulation time step direction_cell_params: Parameters for DirectionCellNetwork. If None, uses defaults grid_cell_params: Parameters for GridCellNetwork. If None, uses defaults theta_params: Parameters for theta modulation. If None, uses defaults spatial_nav_params: Additional parameters for OpenLoopNavigationTask. If None, uses defaults """ super().__init__() # Store trajectory data
[docs] self.trajectory_data = np.array(trajectory_data)
[docs] self.times = np.array(times) if times is not None else None
[docs] self.env_size = env_size
[docs] self.dt = dt
# Validate trajectory data self._validate_trajectory_data() # Set up default parameters
[docs] self.direction_cell_params = self._get_default_direction_cell_params()
if direction_cell_params: self.direction_cell_params.update(direction_cell_params)
[docs] self.grid_cell_params = self._get_default_grid_cell_params()
if grid_cell_params: self.grid_cell_params.update(grid_cell_params)
[docs] self.theta_params = self._get_default_theta_params()
if theta_params: self.theta_params.update(theta_params)
[docs] self.spatial_nav_params = self._get_default_spatial_nav_params()
if spatial_nav_params: self.spatial_nav_params.update(spatial_nav_params) # Initialize components
[docs] self.spatial_nav_task = None
[docs] self.direction_network = None
[docs] self.grid_network = None
def _validate_trajectory_data(self): """ Validate input trajectory data format and dimensions. Checks: - Trajectory is 2D array (n_steps, n_dims) - Only 2D spatial trajectories (n_dims=2) - At least 2 time steps - Times array matches trajectory length if provided Raises: ValueError: If validation fails """ if self.trajectory_data.ndim != 2: raise ValueError("trajectory_data must be a 2D array with shape (n_steps, n_dims)") n_steps, n_dims = self.trajectory_data.shape if n_dims != 2: raise ValueError("Currently only 2D trajectories are supported") if n_steps < 2: raise ValueError("trajectory_data must contain at least 2 time steps") if self.times is not None: if self.times.shape[0] != n_steps: raise ValueError("times array length must match trajectory_data length") def _get_default_direction_cell_params(self) -> dict[str, Any]: """ Get default parameters for DirectionCellNetwork initialization. Returns: dict: Default parameters including: - num: 100 neurons - adaptation_strength: 15 for SFA dynamics - noise_strength: 0.0 (no noise) """ return { "num": 100, "adaptation_strength": 15, "noise_strength": 0.0, } def _get_default_grid_cell_params(self) -> dict[str, Any]: """ Get default parameters for GridCellNetwork initialization. Returns: dict: Default parameters including: - num_gc_x: 100 neurons per dimension (100x100 grid) - adaptation_strength: 8 for SFA dynamics - mapping_ratio: 5 (controls grid spacing) - noise_strength: 0.0 (no noise) """ return { "num_gc_x": 100, "adaptation_strength": 8, "mapping_ratio": 5, "noise_strength": 0.0, } def _get_default_theta_params(self) -> dict[str, Any]: """ Get default parameters for theta oscillation modulation. Returns: dict: Default parameters including: - theta_strength_hd: 1.0 for direction cells - theta_strength_gc: 0.5 for grid cells - theta_cycle_len: 100.0 ms per cycle """ return { "theta_strength_hd": 1.0, "theta_strength_gc": 0.5, "theta_cycle_len": 100.0, } def _get_default_spatial_nav_params(self) -> dict[str, Any]: """ Get default parameters for OpenLoopNavigationTask initialization. Returns: dict: Default parameters including environment size, dt, etc. """ return { "width": self.env_size, "height": self.env_size, "dt": self.dt, "progress_bar": False, } def _setup_open_loop_navigation_task(self): """ Set up and configure the spatial navigation task with trajectory data. Creates OpenLoopNavigationTask, imports external trajectory data, and calculates theta sweep parameters (velocity, angular speed, etc.). """ # Calculate duration from trajectory data if self.times is not None: duration = self.times[-1] - self.times[0] else: duration = len(self.trajectory_data) * self.dt # Create spatial navigation task self.spatial_nav_task = OpenLoopNavigationTask(duration=duration, **self.spatial_nav_params) # Import external trajectory data self.spatial_nav_task.import_data( position_data=self.trajectory_data, times=self.times, dt=self.dt ) # Calculate theta sweep data self.spatial_nav_task.calculate_theta_sweep_data() def _setup_neural_networks(self): """ Initialize and configure direction cell and grid cell networks. Creates DirectionCellNetwork and GridCellNetwork instances with configured parameters and initializes their states. """ # Create direction cell network self.direction_network = DirectionCellNetwork(**self.direction_cell_params) # Create grid cell network (ensure consistency with direction network) grid_params = self.grid_cell_params.copy() grid_params["num_dc"] = self.direction_network.num self.grid_network = GridCellNetwork(**grid_params) def _run_simulation(self): """ Run the main theta sweep simulation loop. Executes time-stepped simulation of direction and grid cell networks with theta modulation. Records neural activity, theta phase, and decoded positions at each time step. Returns: dict: Simulation results containing: - dc_activity: Direction cell firing rates over time - gc_activity: Grid cell firing rates over time - gc_center_phase: Grid cell bump centers in phase space - gc_center_position: Decoded positions from grid cells - theta_phase: Theta oscillation phase over time """ # Set BrainState environment bm.set_dt(dt=1.0) # Extract data from spatial navigation task snt_data = self.spatial_nav_task.data position = snt_data.position direction = snt_data.hd_angle linear_speed_gains = snt_data.linear_speed_gains ang_speed_gains = snt_data.ang_speed_gains def run_step(i, pos, hd_angle, linear_gain, ang_gain): """Single simulation step.""" theta_phase, theta_modulation_hd, theta_modulation_gc = calculate_theta_modulation( time_step=i, linear_gain=linear_gain, ang_gain=ang_gain, theta_strength_hd=self.theta_params["theta_strength_hd"], theta_strength_gc=self.theta_params["theta_strength_gc"], theta_cycle_len=self.theta_params["theta_cycle_len"], dt=self.dt, ) # Update direction cell network self.direction_network(hd_angle, theta_modulation_hd) dc_activity = self.direction_network.r.value # Update grid cell network self.grid_network(pos, dc_activity, theta_modulation_gc) gc_activity = self.grid_network.r.value return ( self.grid_network.center_position.value, self.direction_network.center.value, gc_activity, self.grid_network.gc_bump.value, dc_activity, theta_phase, theta_modulation_hd, theta_modulation_gc, ) # Run compiled simulation loop results = bm.for_loop( run_step, bm.arange(len(position)), position, direction, linear_speed_gains, ang_speed_gains, pbar=None, ) # Unpack results ( internal_position, internal_direction, gc_activity, gc_bump, dc_activity, theta_phase, theta_modulation_hd, theta_modulation_gc, ) = results # Store simulation results self.simulation_results = { "internal_position": internal_position, "internal_direction": internal_direction, "gc_activity": gc_activity, "gc_bump": gc_bump, "dc_activity": dc_activity, "theta_phase": theta_phase, "theta_modulation_hd": theta_modulation_hd, "theta_modulation_gc": theta_modulation_gc, "position": position, "direction": direction, "linear_speed_gains": linear_speed_gains, "ang_speed_gains": ang_speed_gains, "time_steps": self.spatial_nav_task.run_steps, }
[docs] def run( self, output_dir: str | Path = "theta_sweep_results", save_animation: bool = True, save_plots: bool = True, show_plots: bool = False, animation_fps: int = 10, animation_dpi: int = 120, verbose: bool = True, ) -> dict[str, Any]: """ Run the complete theta sweep pipeline. Args: output_dir: Directory to save output files save_animation: Whether to save the theta sweep animation save_plots: Whether to save analysis plots show_plots: Whether to display plots interactively animation_fps: Frame rate for animation animation_dpi: DPI for animation output verbose: Whether to print progress messages Returns: Dictionary containing paths to generated files and analysis data """ self.reset() if verbose: print("🚀 Starting Theta Sweep Pipeline...") # Create output directory output_path = self.prepare_output_dir(output_dir) # Setup pipeline components if verbose: print("📊 Setting up spatial navigation task...") self._setup_open_loop_navigation_task() if verbose: print("🧠 Setting up neural networks...") self._setup_neural_networks() if verbose: print("⚡ Running theta sweep simulation...") self._run_simulation() # Generate outputs outputs = {"data": self.simulation_results} if save_plots or show_plots: outputs.update(self._generate_plots(output_path, show_plots, verbose)) if save_animation: outputs.update( self._generate_animation(output_path, animation_fps, animation_dpi, verbose) ) if verbose: print("✅ Pipeline completed successfully!") print(f"📁 Results saved to: {output_path.absolute()}") return self.set_results(outputs)
def _generate_plots(self, output_path: Path, show_plots: bool, verbose: bool) -> dict[str, str]: """ Generate analysis plots for theta sweep results. Creates trajectory analysis and population activity visualizations. Args: output_path: Directory to save plots show_plots: Whether to display plots interactively verbose: Whether to print progress messages Returns: dict: Mapping of plot names to file paths """ plot_outputs = {} # Trajectory analysis if verbose: print("📈 Generating trajectory analysis...") trajectory_path = output_path / "trajectory_analysis.png" self.spatial_nav_task.show_trajectory_analysis( save_path=str(trajectory_path), show=show_plots, smooth_window=50 ) plot_outputs["trajectory_analysis"] = str(trajectory_path) # Population activity with theta if verbose: print("📊 Generating population activity plot...") config_pop = PlotConfig( title="Direction Cell Population Activity with Theta", xlabel="Time (s)", ylabel="Direction (°)", figsize=(10, 4), show=show_plots, save_path=str(output_path / "population_activity.png"), ) plot_population_activity_with_theta( time_steps=self.simulation_results["time_steps"] * self.dt, theta_phase=self.simulation_results["theta_phase"], net_activity=self.simulation_results["dc_activity"], direction=self.simulation_results["direction"], config=config_pop, add_lines=True, atol=5e-2, ) plot_outputs["population_activity"] = str(output_path / "population_activity.png") return plot_outputs def _generate_animation( self, output_path: Path, fps: int, dpi: int, verbose: bool ) -> dict[str, str]: """ Generate theta sweep animation showing neural dynamics over time. Creates animated visualization of direction and grid cell activity with theta phase modulation. Args: output_path: Directory to save animation fps: Frames per second for animation dpi: Resolution for animation frames verbose: Whether to print progress messages Returns: dict: Mapping containing 'animation' key with file path """ animation_path = output_path / "theta_sweep_animation.gif" config_animation = PlotConfig( figsize=(12, 3), fps=fps, save_path=str(animation_path), show=False, ) if verbose: print("🎬 Creating theta sweep animation...") import sys sys.stdout.flush() # Ensure message is printed before animation starts # Brief pause to ensure message ordering import time time.sleep(0.01) create_theta_sweep_grid_cell_animation( position_data=self.simulation_results["position"], direction_data=self.simulation_results["direction"], dc_activity_data=self.simulation_results["dc_activity"], gc_activity_data=self.simulation_results["gc_activity"], gc_network=self.grid_network, env_size=self.env_size, mapping_ratio=self.grid_cell_params["mapping_ratio"], dt=self.dt, config=config_animation, n_step=10, show_progress_bar=verbose, render_backend="auto", output_dpi=dpi, render_worker_batch_size=2, ) return {"animation_path": str(animation_path)}
# Convenience functions for common use cases
[docs] def load_trajectory_from_csv( filepath: str | Path, x_col: str = "x", y_col: str = "y", time_col: str | None = "time", **kwargs, ) -> dict[str, Any]: """ Load trajectory data from CSV file and run theta sweep analysis. Args: filepath: Path to CSV file x_col: Column name for x coordinates y_col: Column name for y coordinates time_col: Column name for time data (optional) **kwargs: Additional parameters passed to ThetaSweepPipeline Returns: Dictionary containing analysis results and file paths """ import pandas as pd df = pd.read_csv(filepath) trajectory_data = df[[x_col, y_col]].values times = df[time_col].values if time_col and time_col in df.columns else None pipeline = ThetaSweepPipeline(trajectory_data, times, **kwargs) return pipeline.run(verbose=True)
[docs] def batch_process_trajectories( trajectory_list: list, output_base_dir: str = "batch_results", **kwargs ) -> dict[str, dict[str, Any]]: """ Process multiple trajectories in batch. Args: trajectory_list: List of (trajectory_data, times) tuples or trajectory_data arrays output_base_dir: Base directory for batch results **kwargs: Additional parameters passed to ThetaSweepPipeline Returns: Dictionary mapping trajectory indices to results """ batch_results = {} for i, trajectory_input in enumerate(trajectory_list): print(f"\n🔄 Processing trajectory {i + 1}/{len(trajectory_list)}...") if isinstance(trajectory_input, tuple): trajectory_data, times = trajectory_input else: trajectory_data, times = trajectory_input, None output_dir = Path(output_base_dir) / f"trajectory_{i:03d}" try: pipeline = ThetaSweepPipeline(trajectory_data, times, **kwargs) results = pipeline.run(output_dir=str(output_dir), verbose=False) batch_results[f"trajectory_{i:03d}"] = results print(f"✅ Trajectory {i + 1} completed successfully") except Exception as e: print(f"❌ Error processing trajectory {i + 1}: {e}") batch_results[f"trajectory_{i:03d}"] = {"error": str(e)} return batch_results