Source code for src.canns.analyzer.plotting.config

"""Reusable plotting configuration utilities for analyzer visualizations."""

from dataclasses import dataclass
from typing import Any

import numpy as np

__all__ = ["PlotConfig", "PlotConfigs"]


@dataclass
[docs] class PlotConfig: """Unified configuration class for all plotting helpers in ``canns.analyzer``. This mirrors the behaviour of the previous ``visualize`` module so that reorganising the files does not affect the public API. The attributes map directly to keyword arguments exposed by the high-level plotting functions, allowing users to keep existing configuration objects unchanged after the reorganisation. """
[docs] title: str = ""
[docs] xlabel: str = ""
[docs] ylabel: str = ""
[docs] figsize: tuple[int, int] = (10, 6)
[docs] grid: bool = False
[docs] save_path: str | None = None
[docs] show: bool = True
[docs] time_steps_per_second: int | None = None
[docs] fps: int = 30
[docs] repeat: bool = True
[docs] show_progress_bar: bool = True
[docs] show_legend: bool = True
[docs] color: str = "black"
[docs] clabel: str = "Value"
[docs] kwargs: dict[str, Any] | None = None
[docs] def __post_init__(self) -> None: if self.kwargs is None: self.kwargs = {}
@classmethod
[docs] def for_static_plot(cls, **kwargs: Any) -> "PlotConfig": """Return configuration tailored for static plots.""" config = cls(**kwargs) config.time_steps_per_second = None return config
@classmethod
[docs] def for_animation(cls, time_steps_per_second: int, **kwargs: Any) -> "PlotConfig": """Return configuration tailored for animations.""" return cls(time_steps_per_second=time_steps_per_second, **kwargs)
[docs] def to_matplotlib_kwargs(self) -> dict[str, Any]: """Materialize matplotlib keyword arguments from the config.""" return self.kwargs.copy() if self.kwargs else {}
[docs] class PlotConfigs: """Collection of commonly used plot configurations. These helpers mirror the presets that existed in ``canns.analyzer.visualize`` so that callers relying on them continue to receive the exact same defaults. """ @staticmethod
[docs] def energy_landscape_1d_static(**kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "1D Energy Landscape", "xlabel": "Collective Variable / State", "ylabel": "Energy", "figsize": (10, 6), } defaults.update(kwargs) return PlotConfig.for_static_plot(**defaults)
@staticmethod
[docs] def energy_landscape_1d_animation(**kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "Evolving 1D Energy Landscape", "xlabel": "Collective Variable / State", "ylabel": "Energy", "figsize": (10, 6), "fps": 30, } time_steps = kwargs.pop("time_steps_per_second", 1000) defaults.update(kwargs) return PlotConfig.for_animation(time_steps, **defaults)
@staticmethod
[docs] def energy_landscape_2d_static(**kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "2D Static Landscape", "xlabel": "X-Index", "ylabel": "Y-Index", "clabel": "Value", "figsize": (8, 7), } defaults.update(kwargs) return PlotConfig.for_static_plot(**defaults)
@staticmethod
[docs] def energy_landscape_2d_animation(**kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "Evolving 2D Landscape", "xlabel": "X-Index", "ylabel": "Y-Index", "clabel": "Value", "figsize": (8, 7), "fps": 30, } time_steps = kwargs.pop("time_steps_per_second", 1000) defaults.update(kwargs) return PlotConfig.for_animation(time_steps, **defaults)
@staticmethod
[docs] def raster_plot(mode: str = "block", **kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "Raster Plot", "xlabel": "Time Step", "ylabel": "Neuron Index", "figsize": (12, 6), "color": "black", } defaults.update(kwargs) config = PlotConfig.for_static_plot(**defaults) config.mode = mode return config
@staticmethod
[docs] def average_firing_rate_plot(mode: str = "per_neuron", **kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "Average Firing Rate", "figsize": (12, 5), } defaults.update(kwargs) config = PlotConfig.for_static_plot(**defaults) config.mode = mode return config
@staticmethod
[docs] def theta_population_activity_static(**kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "Population Activity with Theta", "xlabel": "Time (s)", "ylabel": "Direction (°)", "figsize": (12, 4), } plot_kwargs: dict[str, Any] = {"cmap": "jet"} plot_kwargs.update(kwargs.pop("kwargs", {})) defaults["kwargs"] = plot_kwargs defaults.update(kwargs) return PlotConfig.for_static_plot(**defaults)
@staticmethod
[docs] def grid_cell_manifold_static(**kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "title": "Grid Cell Activity on Manifold", "figsize": (8, 6), } plot_kwargs: dict[str, Any] = {"cmap": "jet", "add_colorbar": True} plot_kwargs.update(kwargs.pop("kwargs", {})) defaults["kwargs"] = plot_kwargs defaults.update(kwargs) return PlotConfig.for_static_plot(**defaults)
@staticmethod
[docs] def theta_sweep_animation(**kwargs: Any) -> PlotConfig: defaults: dict[str, Any] = { "figsize": (12, 3), "fps": 10, "show_progress_bar": True, } animation_kwargs: dict[str, Any] = { "cmap": "jet", "alpha": 0.8, "trajectory_color": "#FFFFFF", "trajectory_outline": "#1A1A1A", "current_marker_color": "#FF2D00", } animation_kwargs.update(kwargs.pop("kwargs", {})) defaults["kwargs"] = animation_kwargs time_steps = kwargs.pop("time_steps_per_second", None) defaults.update(kwargs) defaults["time_steps_per_second"] = time_steps return PlotConfig(**defaults)
@staticmethod
[docs] def tuning_curve( num_bins: int = 50, pref_stim: np.ndarray | None = None, **kwargs: Any, ) -> PlotConfig: defaults: dict[str, Any] = { "title": "Tuning Curve", "xlabel": "Stimulus Value", "ylabel": "Average Firing Rate", "figsize": (10, 6), } defaults.update(kwargs) config = PlotConfig.for_static_plot(**defaults) config.num_bins = num_bins config.pref_stim = pref_stim return config