Source code for src.canns.analyzer.plotting.spikes

"""Spike train visualization helpers."""

from __future__ import annotations

from typing import Any

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

from .config import PlotConfig, PlotConfigs

# Public plotting helpers exported by ``from <this module> import *``.
__all__ = ["raster_plot", "average_firing_rate_plot", "population_activity_heatmap"]


def _ensure_plot_config(
    config: PlotConfig | None,
    factory,
    *,
    kwargs: dict[str, Any] | None = None,
    **defaults: Any,
) -> PlotConfig:
    """Return a ready-to-use plot configuration.

    Args:
        config: Caller-supplied configuration, or ``None`` to build one.
        factory: Callable invoked as ``factory(**defaults, kwargs=...)`` when
            ``config`` is ``None``.
        kwargs: Extra Matplotlib keyword arguments. When ``config`` is given
            they are merged into ``config.kwargs`` (later keys win); when it is
            ``None`` they are handed to ``factory`` (``{}`` if empty/``None``).
        **defaults: Keyword arguments forwarded to ``factory``.

    Returns:
        The supplied ``config`` (possibly with updated ``kwargs``) or the
        object produced by ``factory``.
    """
    if config is not None:
        # Nothing to merge: hand the caller's config back untouched.
        if not kwargs:
            return config
        merged = config.kwargs or {}
        merged.update(kwargs)
        config.kwargs = merged
        return config

    defaults["kwargs"] = kwargs if kwargs else {}
    return factory(**defaults)


def raster_plot(
    spike_train: np.ndarray,
    config: PlotConfig | None = None,
    *,
    mode: str = "block",
    title: str = "Raster Plot",
    xlabel: str = "Time Step",
    ylabel: str = "Neuron Index",
    figsize: tuple[int, int] = (12, 6),
    color: str = "black",
    save_path: str | None = None,
    show: bool = True,
    **kwargs: Any,
):
    """Generate a raster plot from a spike train matrix.

    The explanatory text mirrors the former ``visualize`` module so callers see
    the same guidance after the reorganisation.

    Args:
        spike_train: Boolean/integer array (or array-like) of shape
            ``(timesteps, neurons)``; non-zero entries are drawn as spikes.
        config: Optional :class:`PlotConfig` with shared styling options. The
            renderer-specific ``marker_size`` (scatter) and ``cmap`` (block)
            entries of ``config.kwargs`` are read, not removed, so the same
            config can be reused across calls.
        mode: Either ``"scatter"`` or ``"block"`` to pick the rendering style.
        title: Plot title when ``config`` is not provided.
        xlabel: X-axis label when ``config`` is not provided.
        ylabel: Y-axis label when ``config`` is not provided.
        figsize: Figure size forwarded to Matplotlib when creating the axes.
        color: Spike colour (or "on" colour for block mode).
        save_path: Optional path used to persist the plot.
        show: Whether to display the plot interactively.
        **kwargs: Additional keyword arguments passed through to Matplotlib.

    Returns:
        tuple: ``(figure, axis)``. The figure is closed with ``plt.close``
        before returning.

    Raises:
        ValueError: If ``spike_train`` is not 2D, is empty, or the mode is not
            ``"scatter"`` or ``"block"``.
    """
    config = _ensure_plot_config(
        config,
        PlotConfigs.raster_plot,
        mode=mode,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        color=color,
        save_path=save_path,
        show=show,
        kwargs=kwargs,
    )
    if not hasattr(config, "mode"):
        config.mode = mode

    spike_train = np.asarray(spike_train)
    if spike_train.ndim != 2:
        raise ValueError(f"Input spike_train must be a 2D array, but got shape {spike_train.shape}")
    if spike_train.size == 0:
        raise ValueError("Input spike_train must not be empty.")
    if config.mode not in {"block", "scatter"}:
        raise ValueError(f"Invalid mode '{config.mode}'. Choose 'scatter' or 'block'.")

    num_timesteps, num_neurons = spike_train.shape
    # Read renderer-specific options without popping them out of the
    # (possibly caller-owned) config; ``kwargs`` may also be None on a
    # user-supplied config.
    user_kwargs = config.kwargs or {}
    # Copy so that stripping our own keys below cannot touch config state.
    plot_kwargs = dict(config.to_matplotlib_kwargs())

    fig, ax = plt.subplots(figsize=config.figsize)
    try:
        ax.set_title(config.title, fontsize=16, fontweight="bold")
        ax.set_xlabel(config.xlabel, fontsize=12)
        ax.set_ylabel(config.ylabel, fontsize=12)

        if config.mode == "scatter":
            time_indices, neuron_indices = np.nonzero(spike_train)
            marker_size = user_kwargs.get("marker_size", 1.0)
            # ``marker_size`` is consumed as ``s=``; never forward it verbatim.
            plot_kwargs.pop("marker_size", None)
            ax.scatter(
                time_indices,
                neuron_indices,
                s=marker_size,
                c=config.color,
                marker="|",
                alpha=0.8,
                **plot_kwargs,
            )
            ax.set_xlim(0, num_timesteps)
            ax.set_ylim(-1, num_neurons)
        else:
            cmap = user_kwargs.get("cmap", ListedColormap(["white", config.color]))
            plot_kwargs.pop("cmap", None)
            ax.imshow(
                spike_train.T,
                aspect="auto",
                interpolation="none",
                cmap=cmap,
                **plot_kwargs,
            )
            if num_neurons > 20:
                # Let the locator pick ~10 integer ticks. Fixed tick labels must
                # not be combined with a non-fixed locator: FixedFormatter labels
                # by tick *index*, which would mislabel the axis.
                ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True, nbins=10))
            else:
                ax.set_yticks(np.arange(num_neurons))
                ax.set_yticklabels(np.arange(num_neurons))

        if config.save_path:
            fig.savefig(config.save_path, dpi=300, bbox_inches="tight")
            print(f"Plot saved to: {config.save_path}")
        if config.show:
            plt.show()
    finally:
        plt.close(fig)
    return fig, ax
def average_firing_rate_plot(
    spike_train: np.ndarray,
    dt: float,
    config: PlotConfig | None = None,
    *,
    mode: str = "population",
    weights: np.ndarray | None = None,
    title: str = "Average Firing Rate",
    figsize: tuple[int, int] = (12, 5),
    save_path: str | None = None,
    show: bool = True,
    **kwargs: Any,
):
    """Calculate and plot average neural activity from a spike train.

    Args:
        spike_train: Boolean/integer array (or array-like) of shape
            ``(timesteps, neurons)``.
        dt: Simulation time step in seconds.
        config: Optional :class:`PlotConfig` with styling overrides.
        mode: One of ``"per_neuron"`` (total spikes / duration per neuron),
            ``"population"`` (summed spikes per step / ``dt``) or
            ``"weighted_average"`` (spike-weighted mean of ``weights`` per
            step; NaN where no neuron spikes).
        weights: Neuron-wise weights of shape ``(neurons,)``, required for
            ``"weighted_average"``.
        title: Plot title when ``config`` is not provided.
        figsize: Figure size forwarded to Matplotlib when creating the axes.
        save_path: Optional path used to persist the plot.
        show: Whether to display the plot interactively.
        **kwargs: Additional keyword arguments forwarded to Matplotlib.

    Returns:
        tuple: ``(figure, axis)``. The figure is closed with ``plt.close``
        before returning.

    Raises:
        ValueError: If ``spike_train`` is not 2D or is empty, the mode is
            unknown, or ``weights`` is missing/mis-shaped in
            ``"weighted_average"`` mode.
    """
    config = _ensure_plot_config(
        config,
        PlotConfigs.average_firing_rate_plot,
        mode=mode,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs=kwargs,
    )
    if not hasattr(config, "mode"):
        config.mode = mode

    # Validate everything up front so no figure is created for bad input.
    spike_train = np.asarray(spike_train)
    if spike_train.ndim != 2:
        raise ValueError("Input spike_train must be a 2D array.")
    if spike_train.size == 0:
        raise ValueError("Input spike_train must not be empty.")
    if config.mode not in ("per_neuron", "population", "weighted_average"):
        raise ValueError(
            f"Invalid mode '{config.mode}'. Choose 'per_neuron', 'population', or 'weighted_average'."
        )

    num_timesteps, num_neurons = spike_train.shape
    if config.mode == "weighted_average":
        if weights is None:
            raise ValueError("'weights' argument is required for 'weighted_average' mode.")
        weights = np.asarray(weights)
        if weights.shape != (num_neurons,):
            raise ValueError(
                f"Shape of 'weights' {weights.shape} must match num_neurons ({num_neurons})."
            )

    time_vector = np.arange(num_timesteps) * dt

    fig, ax = plt.subplots(figsize=config.figsize)
    try:
        ax.set_title(config.title, fontsize=16, fontweight="bold")

        if config.mode == "per_neuron":
            duration_s = num_timesteps * dt
            rates = np.sum(spike_train, axis=0) / duration_s
            ax.plot(np.arange(num_neurons), rates, **config.to_matplotlib_kwargs())
            ax.set_xlabel("Neuron Index", fontsize=12)
            ax.set_ylabel("Average Firing Rate (Hz)", fontsize=12)
            ax.set_xlim(0, num_neurons - 1)
        elif config.mode == "population":
            rates = np.sum(spike_train, axis=1) / dt
            ax.plot(time_vector, rates, **config.to_matplotlib_kwargs())
            ax.set_xlabel("Time (s)", fontsize=12)
            ax.set_ylabel("Total Population Rate (Hz)", fontsize=12)
            ax.set_xlim(0, time_vector[-1])
        else:  # "weighted_average" (mode validated above)
            total_spikes = np.sum(spike_train, axis=1)
            weighted_sum = np.sum(spike_train * weights, axis=1)
            # Exact division; silent steps (0/0) are masked to NaN below, so no
            # epsilon is needed to bias the non-silent values.
            with np.errstate(divide="ignore", invalid="ignore"):
                decoded = weighted_sum / total_spikes
            decoded[total_spikes == 0] = np.nan
            ax.plot(time_vector, decoded, **config.to_matplotlib_kwargs())
            ax.set_xlabel("Time (s)", fontsize=12)
            ax.set_ylabel("Decoded Value (Weighted Average)", fontsize=12)
            ax.set_xlim(0, time_vector[-1])

        ax.grid(True, linestyle="--", alpha=0.6)

        if config.save_path:
            fig.savefig(config.save_path, dpi=300, bbox_inches="tight")
            print(f"Plot saved to: {config.save_path}")
        if config.show:
            plt.show()
    finally:
        plt.close(fig)
    return fig, ax
def population_activity_heatmap(
    activity_data: np.ndarray,
    dt: float,
    config: PlotConfig | None = None,
    *,
    title: str = "Population Activity",
    xlabel: str = "Time (s)",
    ylabel: str = "Neuron Index",
    figsize: tuple[int, int] = (10, 6),
    cmap: str = "viridis",
    save_path: str | None = None,
    show: bool = True,
    **kwargs: Any,
):
    """Generate a heatmap of population firing rate activity over time.

    This function creates a 2D visualization where each row represents a neuron
    and each column represents a time point, with color indicating the firing
    rate or activity level. Column ``i`` spans ``[i*dt, (i+1)*dt)`` on the time
    axis and row ``j`` spans ``[j, j+1)`` on the neuron axis.

    Args:
        activity_data: 2D array (or array-like) of shape ``(timesteps, neurons)``
            containing firing rates or activity values.
        dt: Simulation time step in seconds.
        config: Optional :class:`PlotConfig` with styling overrides. A ``cmap``
            already stored in ``config.kwargs`` is only overridden when the
            ``cmap`` argument is changed from its default.
        title: Plot title when ``config`` is not provided.
        xlabel: X-axis label when ``config`` is not provided.
        ylabel: Y-axis label when ``config`` is not provided.
        figsize: Figure size forwarded to Matplotlib when creating the axes.
        cmap: Colormap name (default: "viridis").
        save_path: Optional path used to persist the plot.
        show: Whether to display the plot interactively.
        **kwargs: Additional keyword arguments forwarded to Matplotlib.

    Returns:
        tuple: (figure, axis) objects. On success the figure is left open; it
        is closed only if plotting raises.

    Raises:
        ValueError: If ``activity_data`` is not 2D or is empty.

    Example:
        >>> import numpy as np
        >>> from canns.analyzer.plotting.spikes import population_activity_heatmap
        >>> # Simulate some activity data
        >>> activity = np.random.rand(1000, 100)  # 1000 timesteps, 100 neurons
        >>> fig, ax = population_activity_heatmap(activity, dt=0.001)
    """
    if config is None:
        config = PlotConfig(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            figsize=figsize,
            save_path=save_path,
            show=show,
            kwargs={"cmap": cmap, **kwargs},
        )
    else:
        # Forward ``cmap`` only when explicitly changed from the default so a
        # colormap stored in ``config.kwargs`` is not clobbered by "viridis"
        # just because other kwargs were passed. (When absent, the imshow call
        # below falls back to ``cmap`` anyway.)
        overrides = dict(kwargs)
        if cmap != "viridis":
            overrides["cmap"] = cmap
        if overrides:
            config.kwargs = {**(config.kwargs or {}), **overrides}

    activity_data = np.asarray(activity_data)
    if activity_data.ndim != 2:
        raise ValueError(
            f"Input activity_data must be a 2D array, but got shape {activity_data.shape}"
        )
    if activity_data.size == 0:
        raise ValueError("Input activity_data must not be empty.")

    num_timesteps, num_neurons = activity_data.shape

    fig, ax = plt.subplots(figsize=config.figsize)
    try:
        # Copy before popping so config state is never modified here.
        plot_kwargs = dict(config.to_matplotlib_kwargs())
        cmap_name = plot_kwargs.pop("cmap", cmap)

        # Transpose so rows are neurons and columns are time steps. The time
        # extent covers all T bins ([0, T*dt]), matching the neuron-axis
        # convention and staying non-degenerate for a single timestep.
        im = ax.imshow(
            activity_data.T,
            aspect="auto",
            extent=[0.0, num_timesteps * dt, 0, num_neurons],
            origin="lower",
            cmap=cmap_name,
            **plot_kwargs,
        )

        ax.set_title(config.title, fontsize=16, fontweight="bold")
        ax.set_xlabel(config.xlabel, fontsize=12)
        ax.set_ylabel(config.ylabel, fontsize=12)

        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label("Activity", fontsize=10)

        fig.tight_layout()

        if config.save_path:
            fig.savefig(config.save_path, dpi=300, bbox_inches="tight")
            print(f"Plot saved to: {config.save_path}")
        if config.show:
            plt.show()
    except Exception:
        # Release the half-built figure, then re-raise with the original traceback.
        plt.close(fig)
        raise
    return fig, ax