Source code for src.canns.analyzer.experimental_data.cann2d

import logging
import multiprocessing as mp
import numbers
import os
from dataclasses import dataclass
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from canns_lib.ripser import ripser
from matplotlib import animation, cm, gridspec
from numpy.exceptions import AxisError

# from ripser import ripser
from scipy import signal
from scipy.ndimage import (
    _nd_image,
    _ni_support,
    binary_closing,
    gaussian_filter,
    gaussian_filter1d,
)
from scipy.ndimage._filters import _invalid_origin
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import lsmr
from scipy.spatial.distance import pdist, squareform
from scipy.stats import binned_statistic_2d, multivariate_normal
from sklearn import preprocessing
from tqdm import tqdm

from canns.analyzer.plotting.jupyter_utils import (
    display_animation_in_jupyter,
    is_jupyter_environment,
)

# Import PlotConfig for unified plotting
from ..plotting import PlotConfig


# ==================== Configuration Classes ====================
@dataclass
class SpikeEmbeddingConfig:
    """Settings that control how spike trains are binned, smoothed and filtered.

    Attributes:
        res: Factor that raw time stamps and spike times are multiplied by
            before they are placed into bins.
        dt: Spacing between consecutive bin edges, in the scaled time units.
        sigma: Gaussian smoothing width in the scaled time units; the filter
            itself is run with ``sigma / dt`` bins.
        smooth: Whether the binned matrix is temporally smoothed.
        speed_filter: Whether samples are filtered by movement speed.
        min_speed: Samples are kept only where speed is strictly greater
            than this value (used when ``speed_filter`` is on).
    """

    res: int = 100000
    dt: int = 1000
    sigma: int = 5000
    smooth: bool = True
    speed_filter: bool = True
    min_speed: float = 2.5
@dataclass
class TDAConfig:
    """Parameters for the Topological Data Analysis pipeline.

    Attributes:
        dim: Number of components requested from the PCA reduction step.
        num_times: Stride used when downsampling time points.
        active_times: How many of the highest-activity time points are kept.
        k: Forwarded as ``k`` to the point-cloud denoising sampler.
        n_points: Number of points the denoising sampler is asked to return.
        metric: Distance metric name used for denoising and for building the
            distance matrix.
        nbs: Forwarded as ``nbs`` when the distance matrix is built.
        maxdim: Forwarded to ``ripser`` as ``maxdim``.
        coeff: Forwarded to ``ripser`` as ``coeff``.
        show: Passed along to the visualization step via this config
            (presumably toggles plot display; confirm in the handler).
        do_shuffle: Whether the shuffle analysis is run.
        num_shuffles: Number of shuffle iterations to run.
        progress_bar: Forwarded to ``ripser`` and to the shuffle run.
    """

    dim: int = 6
    num_times: int = 5
    active_times: int = 15000
    k: int = 1000
    n_points: int = 1200
    metric: str = "cosine"
    nbs: int = 800
    maxdim: int = 1
    coeff: int = 47
    show: bool = True
    do_shuffle: bool = False
    num_shuffles: int = 1000
    progress_bar: bool = True
@dataclass
class CANN2DPlotConfig(PlotConfig):
    """Specialized PlotConfig for CANN2D visualizations.

    Extends the base plotting configuration with the extra fields needed by
    the 3D projection plot and the torus bump animation.
    """

    # 3D projection specific parameters
    zlabel: str = "Component 3"
    dpi: int = 300

    # Torus animation specific parameters
    numangsint: int = 51
    r1: float = 1.5  # Major radius
    r2: float = 1.0  # Minor radius
    window_size: int = 300
    frame_step: int = 5
    n_frames: int = 20

    @classmethod
    def for_projection_3d(cls, **kwargs) -> "CANN2DPlotConfig":
        """Create configuration for 3D projection plots.

        Parameters:
            **kwargs: Overrides for the projection defaults; forwarded to
                ``for_static_plot`` together with the defaults.

        Returns:
            Whatever ``cls.for_static_plot`` builds from the merged settings.
        """
        defaults = {
            "title": "3D Data Projection",
            "xlabel": "Component 1",
            "ylabel": "Component 2",
            "zlabel": "Component 3",
            "figsize": (10, 8),
            "dpi": 300,
        }
        defaults.update(kwargs)
        return cls.for_static_plot(**defaults)

    @classmethod
    def for_torus_animation(cls, **kwargs) -> "CANN2DPlotConfig":
        """Create configuration for 3D torus bump animations.

        Parameters:
            **kwargs: Overrides for the animation defaults. The optional
                ``time_steps_per_second`` entry (default 1000) is handed to
                ``for_animation`` as its positional argument.

        Returns:
            The configuration built by ``cls.for_animation`` with the
            torus-specific attributes set on it.
        """
        defaults = {
            "title": "3D Bump on Torus",
            "figsize": (8, 8),
            "fps": 5,
            "repeat": True,
            "show_progress_bar": True,
            "numangsint": 51,
            "r1": 1.5,
            "r2": 1.0,
            "window_size": 300,
            "frame_step": 5,
            "n_frames": 20,
        }
        defaults.update(kwargs)

        # Remove the entry from the keyword set so the value reaches
        # for_animation exactly once. Leaving it in `defaults` would supply it
        # both positionally and by keyword.
        # NOTE(review): assumes for_animation binds its first positional
        # parameter to `time_steps_per_second` - confirm against PlotConfig.
        time_steps = defaults.pop("time_steps_per_second", 1000)
        config = cls.for_animation(time_steps, **defaults)

        # Add torus-specific attributes
        for attr in ("numangsint", "r1", "r2", "window_size", "frame_step", "n_frames"):
            setattr(config, attr, defaults[attr])
        return config
# ==================== Constants ====================
class Constants:
    """Namespace of fixed values shared across the CANN2D analysis code."""

    # Plot defaults
    DEFAULT_FIGSIZE = (10, 8)
    DEFAULT_DPI = 300

    # Numeric scaling factors
    GAUSSIAN_SIGMA_FACTOR = 100
    SPEED_CONVERSION_FACTOR = 100
    TIME_CONVERSION_FACTOR = 0.01

    # Worker count handed to the shuffle analysis
    MULTIPROCESSING_CORES = 4
# ==================== Custom Exceptions ====================
class CANN2DError(Exception):
    """Root of the CANN2D exception hierarchy."""


class DataLoadError(CANN2DError):
    """Signals that required input data could not be loaded."""


class ProcessingError(CANN2DError):
    """Signals that a data processing step failed."""
# Optional numba acceleration: when numba is missing, fall back to no-op
# stand-ins so decorated functions still run as plain Python.
try:
    from numba import jit, njit, prange

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False
    print(
        "Using numba for FAST CANN2D analysis, now using pure numpy implementation.",
        "Try numba by `pip install numba` to speed up the process.",
    )

    def jit(*args, **kwargs):
        """No-op replacement for ``numba.jit``.

        Supports both decorator spellings: bare (``@jit``) and parameterized
        (``@jit(nopython=True)``). In either case the decorated function is
        returned unchanged.
        """
        # Bare usage passes the function itself as the only argument; hand it
        # straight back instead of returning the inner decorator.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(func):
            return func

        return decorator

    # Same pass-through behavior for the nopython shorthand.
    njit = jit

    def prange(*args):
        """No-op replacement for ``numba.prange``; behaves like ``range``."""
        return range(*args)
def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
    """
    Turn raw spike times into a time-binned spike matrix.

    The spikes are restricted to the recording window, binned, optionally
    smoothed with a Gaussian kernel, and optionally filtered by the animal's
    movement speed.

    Parameters:
        spike_trains : dict containing 'spike', 't', and optionally 'x', 'y'
            ('x' and 'y' are read only when speed filtering is enabled).
        config : SpikeEmbeddingConfig, optional configuration object. When it
            is given, ``**kwargs`` are not consulted.
        **kwargs : backward compatibility parameters, used only when
            ``config`` is None: ``res``, ``dt``, ``sigma``, ``smooth0``,
            ``speed0`` and ``min_speed``.

    Returns:
        When ``config.speed_filter`` is False: the binned (and optionally
        smoothed) spike matrix of shape (T, N).
        When ``config.speed_filter`` is True: a tuple
        ``(spikes_bin, xx, yy, tt)`` restricted to the samples that pass the
        speed threshold.

    Raises:
        ProcessingError: if any step fails; the original exception is chained.
    """
    if config is None:
        # Legacy callers spell the two switches "smooth0" and "speed0".
        legacy = kwargs.get
        config = SpikeEmbeddingConfig(
            res=legacy("res", 100000),
            dt=legacy("dt", 1000),
            sigma=legacy("sigma", 5000),
            smooth=legacy("smooth0", True),
            speed_filter=legacy("speed0", True),
            min_speed=legacy("min_speed", 2.5),
        )

    try:
        # Keep only the spikes that fall inside the recording window.
        windowed_spikes = _extract_spike_data(spike_trains, config)

        # Build the bin edges, then count spikes per bin and cell.
        bin_edges = _create_time_bins(spike_trains["t"], config)
        binned = _bin_spike_data(windowed_spikes, bin_edges, config)

        if config.smooth:
            binned = _apply_temporal_smoothing(binned, config)

        if not config.speed_filter:
            return binned
        return _apply_speed_filtering(binned, spike_trains, config)
    except Exception as e:
        raise ProcessingError(f"Failed to embed spike trains: {e}") from e
def _extract_spike_data(
    spike_trains: dict[str, Any], config: "SpikeEmbeddingConfig"
) -> dict[int, np.ndarray]:
    """Collect each cell's spike times that fall inside the recording window.

    The window is ``[min(t), max(t))`` taken from ``spike_trains["t"]``.

    Parameters:
        spike_trains: Mapping with a ``"spike"`` entry (a dict of per-cell
            spike times, a list/array of per-cell spike times, or an object
            exposing ``.item`` that unwraps via ``[()]``) and a ``"t"`` entry.
        config: Accepted for signature parity with the other pipeline steps;
            not read here.

    Returns:
        Dict keyed by consecutive integers (enumeration order of the cells)
        whose values are the windowed spike-time arrays.

    Raises:
        DataLoadError: if ``"spike"`` or ``"t"`` is missing.
        ProcessingError: for any other failure while extracting.
    """
    try:
        spike_data = spike_trains["spike"]
        if hasattr(spike_data, "item") and callable(spike_data.item):
            # Objects with .item() (e.g. arrays loaded from an npz file) are
            # unwrapped with an empty-tuple index.
            spikes_all = spike_data[()]
        else:
            # dicts, lists and anything else are used as they are.
            spikes_all = spike_data

        t = spike_trains["t"]
        window_start = np.min(t)
        window_end = np.max(t)

        spikes = {}
        if isinstance(spikes_all, dict):
            for i, cell_spikes in enumerate(spikes_all.values()):
                s = np.array(cell_spikes)
                spikes[i] = s[(s >= window_start) & (s < window_end)]
        else:
            for i in range(len(spikes_all)):
                cell_spikes = spikes_all[i]
                if len(cell_spikes) > 0:
                    s = np.array(cell_spikes)
                    spikes[i] = s[(s >= window_start) & (s < window_end)]
                else:
                    spikes[i] = np.array([])
        return spikes
    except KeyError as e:
        raise DataLoadError(f"Missing required data key: {e}") from e
    except Exception as e:
        raise ProcessingError(f"Error extracting spike data: {e}") from e


def _create_time_bins(t: np.ndarray, config: "SpikeEmbeddingConfig") -> np.ndarray:
    """Create time bins for spike discretization.

    Parameters:
        t: Time stamps; only their minimum and maximum are used.
        config: Supplies ``res`` (scale applied to the time stamps) and ``dt``
            (spacing of the bin edges in scaled units).

    Returns:
        Bin edges from ``floor(min(t) * res)`` up to and including
        ``ceil(max(t) * res)`` where the spacing allows, in steps of ``dt``.
    """
    first_edge = np.floor(np.min(t) * config.res)
    stop = np.ceil(np.max(t) * config.res) + 1
    return np.arange(first_edge, stop, config.dt)


def _bin_spike_data(
    spikes: dict[int, np.ndarray], time_bins: np.ndarray, config: "SpikeEmbeddingConfig"
) -> np.ndarray:
    """Convert spike times to a binned spike-count matrix.

    Parameters:
        spikes: Per-cell spike-time arrays; each key is used directly as the
            column index of the output.
        time_bins: Bin edges as produced by ``_create_time_bins``.
        config: Supplies ``res`` and ``dt``.

    Returns:
        Integer array of shape ``(len(time_bins), len(spikes))`` holding the
        number of spikes per bin and cell.
    """
    min_time = time_bins[0]
    span = time_bins[-1] - min_time
    n_bins = len(time_bins)
    spikes_bin = np.zeros((n_bins, len(spikes)), dtype=int)

    for n in spikes:
        # Offset from the first bin edge in scaled units, truncated to int.
        offsets = np.array(spikes[n] * config.res - min_time, dtype=int)
        # Only offsets strictly inside (0, span) are counted.
        offsets = offsets[(offsets < span) & (offsets > 0)]
        bin_idx = np.array(offsets / config.dt, int)
        bin_idx = bin_idx[bin_idx < n_bins]
        # One histogram call per cell replaces the per-spike Python loop.
        spikes_bin[:, n] += np.bincount(bin_idx, minlength=n_bins)
    return spikes_bin


def _apply_temporal_smoothing(
    spikes_bin: np.ndarray, config: "SpikeEmbeddingConfig"
) -> np.ndarray:
    """Apply Gaussian temporal smoothing to the spike matrix.

    Each column is filtered along the time axis with a Gaussian of
    ``sigma / dt`` bins (zero padding at the edges), then the result is scaled
    by ``1 / sqrt(2 * pi * (sigma / res) ** 2)``.

    Parameters:
        spikes_bin: Binned spike matrix of shape (T, N).
        config: Supplies ``sigma``, ``dt`` and ``res``.

    Returns:
        Float array with the same shape as ``spikes_bin``.
    """
    normalization_factor = 1 / np.sqrt(2 * np.pi * (config.sigma / config.res) ** 2)
    if spikes_bin.size == 0:
        # Nothing to filter; keep the float result type and the shape.
        return np.zeros(spikes_bin.shape) * normalization_factor

    sigma_bins = config.sigma / config.dt
    # Filtering along axis 0 handles every column in a single call.
    smoothed = gaussian_filter1d(
        spikes_bin.astype(float), sigma=sigma_bins, axis=0, mode="constant"
    )
    return smoothed * normalization_factor


def _apply_speed_filtering(
    spikes_bin: np.ndarray, spike_trains: dict[str, Any], config: "SpikeEmbeddingConfig"
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the samples recorded while speed exceeds ``config.min_speed``.

    Parameters:
        spikes_bin: Binned spike matrix whose rows are indexed by the speed
            mask.
        spike_trains: Mapping that must provide ``"t"``, ``"x"`` and ``"y"``.
        config: Supplies ``res``, ``dt`` and ``min_speed``.

    Returns:
        Tuple ``(spikes_bin, xx, yy, tt)``, each restricted to the rows where
        speed is strictly greater than ``min_speed``.

    Raises:
        DataLoadError: if a position key is missing.
        ProcessingError: for any other failure.
    """
    try:
        xx, yy, tt_pos, speed = _load_pos(
            spike_trains["t"], spike_trains["x"], spike_trains["y"], res=config.res, dt=config.dt
        )
        valid = speed > config.min_speed
        return (spikes_bin[valid, :], xx[valid], yy[valid], tt_pos[valid])
    except KeyError as e:
        raise DataLoadError(f"Missing position data for speed filtering: {e}") from e
    except Exception as e:
        raise ProcessingError(f"Error in speed filtering: {e}") from e
def plot_projection(
    reduce_func,
    embed_data,
    config: CANN2DPlotConfig | None = None,
    title="Projection (3D)",
    xlabel="Component 1",
    ylabel="Component 2",
    zlabel="Component 3",
    save_path=None,
    show=True,
    dpi=300,
    figsize=(10, 8),
    **kwargs,
):
    """
    Draw a 3D scatter plot of dimensionality-reduced embedded data.

    Parameters:
        reduce_func (callable): Applied to every fifth row of ``embed_data``;
            its result is indexed as a 2D array whose first three columns are
            plotted.
        embed_data (ndarray): Data to be projected.
        config (CANN2DPlotConfig, optional): Plot configuration. When given,
            the individual keyword arguments below are not used.
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        zlabel (str): Label for the z-axis.
        save_path (str, optional): Path to save the plot. If None, plot will
            not be saved.
        show (bool): Whether to display the plot.
        dpi (int): Dots per inch for saving the figure.
        figsize (tuple): Size of the figure.
        **kwargs: Extra settings forwarded to
            ``CANN2DPlotConfig.for_projection_3d`` when ``config`` is None.

    Returns:
        fig: The created figure object (already closed in pyplot).

    Raises:
        ValueError: if both ``config.save_path`` and ``config.show`` are None.
    """
    if config is None:
        # Backward compatibility: assemble a config from the loose arguments.
        config = CANN2DPlotConfig.for_projection_3d(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            zlabel=zlabel,
            save_path=save_path,
            show=show,
            figsize=figsize,
            dpi=dpi,
            **kwargs,
        )

    # Only every fifth sample is reduced and plotted.
    points = reduce_func(embed_data[::5])
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]

    fig = plt.figure(figsize=config.figsize)
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(xs, ys, zs, s=1, alpha=0.5)

    ax.set_title(config.title)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_zlabel(config.zlabel)

    nothing_requested = config.save_path is None and config.show is None
    if nothing_requested:
        raise ValueError("Either save path or show must be provided.")

    if config.save_path:
        plt.savefig(config.save_path, dpi=config.dpi)
    if config.show:
        plt.show()

    plt.close(fig)
    return fig
def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -> dict[str, Any]:
    """
    Run persistent-homology analysis on embedded data, with optional shuffle testing.

    Parameters:
        embed_data : ndarray
            Embedded spike train data.
        config : TDAConfig, optional
            Configuration object with all TDA parameters. When it is given,
            ``**kwargs`` are not consulted.
        **kwargs : backward compatibility parameters; each TDAConfig field
            name may be supplied and falls back to the TDAConfig default.

    Returns:
        dict : Dictionary containing:
            - persistence: persistence diagrams from real data
            - indstemp: indices of sampled points
            - movetimes: selected time points
            - n_points: number of sampled points
            - shuffle_max: shuffle analysis results (if do_shuffle=True, otherwise None)

    Raises:
        ProcessingError: if any step fails; the original exception is chained.
    """
    if config is None:
        # Backward compatibility: build the config from loose keyword arguments.
        fallback = {
            "dim": 6,
            "num_times": 5,
            "active_times": 15000,
            "k": 1000,
            "n_points": 1200,
            "metric": "cosine",
            "nbs": 800,
            "maxdim": 1,
            "coeff": 47,
            "show": True,
            "do_shuffle": False,
            "num_shuffles": 1000,
            "progress_bar": True,
        }
        config = TDAConfig(**{name: kwargs.get(name, value) for name, value in fallback.items()})

    try:
        print("Computing persistent homology for real data...")
        real = _compute_real_persistence(embed_data, config)

        # The shuffle test is optional; without it the slot stays None.
        shuffle_max = _perform_shuffle_analysis(embed_data, config) if config.do_shuffle else None

        _handle_visualization(real["persistence"], shuffle_max, config)

        result = {key: real[key] for key in ("persistence", "indstemp", "movetimes", "n_points")}
        result["shuffle_max"] = shuffle_max
        return result
    except Exception as e:
        raise ProcessingError(f"TDA analysis failed: {e}") from e
def _compute_real_persistence(embed_data: np.ndarray, config: TDAConfig) -> dict[str, Any]:
    """Run the five-step persistent-homology pipeline on the real data.

    Steps: downsample time points, keep the most active ones, PCA-reduce,
    denoise/subsample the point cloud, then compute persistent homology.

    Returns:
        dict with keys ``persistence`` (ripser result), ``indstemp`` (indices of
        sampled points), ``movetimes`` (selected time indices) and ``n_points``
        (copied from ``config.n_points``).
    """
    logging.info("Processing real data - Starting TDA analysis (5 steps)")

    # Step 1: Time point downsampling
    logging.info("Step 1/5: Time point downsampling")
    times_cube = _downsample_timepoints(embed_data, config.num_times)

    # Step 2: Select most active time points
    logging.info("Step 2/5: Selecting active time points")
    movetimes = _select_active_timepoints(embed_data, times_cube, config.active_times)

    # Step 3: PCA dimensionality reduction
    logging.info("Step 3/5: PCA dimensionality reduction")
    dimred = _apply_pca_reduction(embed_data, movetimes, config.dim)

    # Step 4: Point cloud sampling (denoising)
    logging.info("Step 4/5: Point cloud denoising")
    indstemp = _apply_denoising(dimred, config)

    # Step 5: Compute persistent homology
    logging.info("Step 5/5: Computing persistent homology")
    persistence = _compute_persistence_homology(dimred, indstemp, config)

    logging.info("TDA analysis completed successfully")

    return {
        "persistence": persistence,
        "indstemp": indstemp,
        "movetimes": movetimes,
        "n_points": config.n_points,
    }


def _downsample_timepoints(embed_data: np.ndarray, num_times: int) -> np.ndarray:
    """Return every ``num_times``-th row index of ``embed_data``."""
    return np.arange(0, embed_data.shape[0], num_times)


def _select_active_timepoints(
    embed_data: np.ndarray, times_cube: np.ndarray, active_times: int
) -> np.ndarray:
    """Return the ``active_times`` candidate time indices with the largest summed activity.

    The result is in increasing time order and expressed as indices into
    ``embed_data`` (not into ``times_cube``).
    """
    activity_scores = np.sum(embed_data[times_cube, :], 1)
    # Match external TDAvis: sort indices first, then map to times_cube
    movetimes = np.sort(np.argsort(activity_scores)[-active_times:])
    return times_cube[movetimes]


def _apply_pca_reduction(embed_data: np.ndarray, movetimes: np.ndarray, dim: int) -> np.ndarray:
    """Standardise the selected rows and project them onto ``dim`` principal components."""
    scaled_data = preprocessing.scale(embed_data[movetimes, :])
    dimred, *_ = _pca(scaled_data, dim=dim)
    return dimred


def _apply_denoising(dimred: np.ndarray, config: TDAConfig) -> np.ndarray:
    """Return the indices chosen by ``_sample_denoising`` for the reduced point cloud."""
    indstemp, *_ = _sample_denoising(
        dimred,
        k=config.k,
        num_sample=config.n_points,
        omega=1,  # Match external TDAvis: uses 1, not default 0.2
        metric=config.metric,
    )
    return indstemp


def _compute_persistence_homology(
    dimred: np.ndarray, indstemp: np.ndarray, config: TDAConfig
) -> dict[str, Any]:
    """Build the distance matrix for the sampled points and run ripser on it."""
    d = _second_build(dimred, indstemp, metric=config.metric, nbs=config.nbs)
    np.fill_diagonal(d, 0)

    return ripser(
        d,
        maxdim=config.maxdim,
        coeff=config.coeff,
        do_cocycles=True,
        distance_matrix=True,
        progress_bar=config.progress_bar,
    )


def _perform_shuffle_analysis(embed_data: np.ndarray, config: TDAConfig) -> dict[int, Any]:
    """Run the shuffle null-distribution analysis and print a summary.

    Returns:
        dict mapping homology dimension to the list of per-shuffle maximum
        finite lifetimes, as produced by ``_run_shuffle_analysis``.
    """
    print(f"\nStarting shuffle analysis with {config.num_shuffles} iterations...")

    # Parameters forwarded to every per-shuffle persistence computation.
    shuffle_params = {
        "dim": config.dim,
        "num_times": config.num_times,
        "active_times": config.active_times,
        "k": config.k,
        "n_points": config.n_points,
        "metric": config.metric,
        "nbs": config.nbs,
        "maxdim": config.maxdim,
        "coeff": config.coeff,
    }

    shuffle_max = _run_shuffle_analysis(
        embed_data,
        num_shuffles=config.num_shuffles,
        num_cores=Constants.MULTIPROCESSING_CORES,
        progress_bar=config.progress_bar,
        **shuffle_params,
    )

    _print_shuffle_summary(shuffle_max)

    return shuffle_max


def _print_shuffle_summary(shuffle_max: dict[int, Any]) -> None:
    """Print count, mean and 99.9th percentile of shuffle lifetimes for H0-H2."""
    print("\nSummary of shuffle-based analysis:")
    for dim_idx in [0, 1, 2]:
        if shuffle_max and dim_idx in shuffle_max and shuffle_max[dim_idx]:
            values = shuffle_max[dim_idx]
            print(
                f"H{dim_idx}: {len(values)} valid iterations | "
                f"Mean maximum persistence: {np.mean(values):.4f} | "
                f"99.9th percentile: {np.percentile(values, 99.9):.4f}"
            )


def _handle_visualization(
    real_persistence: dict[str, Any], shuffle_max: dict[int, Any] | None, config: TDAConfig
) -> None:
    """Draw and show the barcode when ``config.show`` is true; otherwise close the current figure."""
    if config.show:
        if config.do_shuffle and shuffle_max is not None:
            _plot_barcode_with_shuffle(real_persistence, shuffle_max)
        else:
            _plot_barcode(real_persistence)
        plt.show()
    else:
        plt.close()


def _load_pos(t, x, y, res=100000, dt=1000):
    """
    Interpolate tracked positions onto regular time bins and compute speed.

    Parameters:
        t (ndarray): Time stamps of the position samples (in seconds).
        x (ndarray): X coordinates of the animal's position.
        y (ndarray): Y coordinates of the animal's position.
        res (int): Number of integer ticks per second used for the time grid.
        dt (int): Bin width in ticks, i.e. each bin lasts ``dt / res`` seconds
            (0.01 s with the defaults).

    Returns:
        xx (ndarray): Interpolated x positions.
        yy (ndarray): Interpolated y positions.
        tt (ndarray): Bin time points (in seconds).
        speed (ndarray): Speed per bin, computed from Gaussian-smoothed
            positions as displacement * 100 / bin duration
            (NOTE(review): cm/s if x and y are in metres - confirm units).
    """
    min_time0 = np.min(t)
    max_time0 = np.max(t)

    # Drop the final sample(s) sitting exactly at the maximum time stamp.
    times = np.where((t >= min_time0) & (t < max_time0))
    x = x[times]
    y = y[times]
    t = t[times]

    min_time = min_time0 * res
    max_time = max_time0 * res

    tt = np.arange(np.floor(min_time), np.ceil(max_time) + 1, dt) / res

    # Assign every bin to the position sample interval it falls into, then
    # linearly interpolate between consecutive samples inside each interval.
    idt = np.concatenate(([0], np.digitize(t[1:-1], tt[:]) - 1, [len(tt) + 1]))
    idtt = np.digitize(np.arange(len(tt)), idt) - 1

    idx = np.concatenate((np.unique(idtt), [np.max(idtt) + 1]))
    divisor = np.bincount(idtt)
    steps = 1.0 / divisor[divisor > 0]
    N = np.max(divisor)
    ranges = np.multiply(np.arange(N)[np.newaxis, :], steps[:, np.newaxis])
    ranges[ranges >= 1] = np.nan

    rangesx = x[idx[:-1], np.newaxis] + np.multiply(
        ranges, (x[idx[1:]] - x[idx[:-1]])[:, np.newaxis]
    )
    xx = rangesx[~np.isnan(ranges)]

    rangesy = y[idx[:-1], np.newaxis] + np.multiply(
        ranges, (y[idx[1:]] - y[idx[:-1]])[:, np.newaxis]
    )
    yy = rangesy[~np.isnan(ranges)]

    xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
    yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
    dx = (xxs[1:] - xxs[:-1]) * 100
    dy = (yys[1:] - yys[:-1]) * 100
    # Divide by the actual bin duration instead of a hard-coded 0.01 s so that
    # non-default res/dt give a correct speed (identical at the defaults).
    bin_seconds = dt / res
    speed = np.sqrt(dx**2 + dy**2) / bin_seconds
    speed = np.concatenate(([speed[0]], speed))
    return xx, yy, tt, speed


def _gaussian_filter1d(
    input,
    sigma,
    axis=-1,
    order=0,
    output=None,
    mode="reflect",
    cval=0.0,
    truncate=4.0,
    *,
    radius=None,
):
    """1-D Gaussian filter (local variant built on ``_correlate1d``).

    Parameters
    ----------
    input : array_like
        The input array.
    sigma : scalar
        Standard deviation for the Gaussian kernel.
    axis : int, optional
        Axis of `input` along which to filter. Default is -1.
    order : int, optional
        An order of 0 corresponds to convolution with a Gaussian
        kernel. A positive order corresponds to convolution with
        that derivative of a Gaussian.
    output : array or dtype, optional
        Forwarded to ``_correlate1d`` as the output specification.
    mode : str, optional
        Boundary-extension mode forwarded to ``_correlate1d``.
        Default is 'reflect'.
    cval : scalar, optional
        Fill value forwarded to ``_correlate1d``. Default is 0.0.
    truncate : float, optional
        Truncate the filter at this many standard deviations.
        Default is 4.0.
    radius : None or int, optional
        Radius of the Gaussian kernel. If specified, the size of
        the kernel will be ``2*radius + 1``, and `truncate` is ignored.
        Default is None.

    Returns
    -------
    gaussian_filter1d : ndarray

    Raises
    ------
    ValueError
        If the resulting radius is not a non-negative integer.

    Notes
    -----
    The Gaussian kernel will have size ``2*radius + 1``.
    If `radius` is None, ``radius = int(truncate * sigma + 0.5)`` is used.

    Examples
    --------
    >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 1)
    array([ 1.42704095,  2.06782203,  3.        ,  3.93217797,  4.57295905])
    >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 4)
    array([ 2.91948343,  2.95023502,  3.        ,  3.04976498,  3.08051657])
    """
    sd = float(sigma)
    # make the radius of the filter equal to truncate standard deviations
    lw = int(truncate * sd + 0.5)
    if radius is not None:
        lw = radius
    if not isinstance(lw, numbers.Integral) or lw < 0:
        raise ValueError("Radius must be a nonnegative integer.")
    # Since we are calling correlate, not convolve, revert the kernel
    weights = _gaussian_kernel1d(sigma, order, lw)[::-1]
    return _correlate1d(input, weights, axis, output, mode, cval, 0)


def _gaussian_kernel1d(sigma, order, radius):
    """
    Compute a 1-D Gaussian convolution kernel of length ``2*radius + 1``.

    For ``order > 0`` the kernel is the ``order``-th derivative of the
    normalised Gaussian. Raises ValueError for a negative ``order``.
    """
    if order < 0:
        raise ValueError("order must be non-negative")
    exponent_range = np.arange(order + 1)
    sigma2 = sigma * sigma
    x = np.arange(-radius, radius + 1)
    phi_x = np.exp(-0.5 / sigma2 * x**2)
    phi_x = phi_x / phi_x.sum()

    if order == 0:
        return phi_x
    else:
        # f(x) = q(x) * phi(x) = q(x) * exp(p(x))
        # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
        # p'(x) = -1 / sigma ** 2
        # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
        # coefficients of q(x)
        q = np.zeros(order + 1)
        q[0] = 1
        D = np.diag(exponent_range[1:], 1)  # D @ q(x) = q'(x)
        P = np.diag(np.ones(order) / -sigma2, -1)  # P @ q(x) = q(x) * p'(x)
        Q_deriv = D + P
        for _ in range(order):
            q = Q_deriv.dot(q)
        q = (x[:, None] ** exponent_range).dot(q)
        return q * phi_x


def _correlate1d(input, weights, axis=-1, output=None, mode="reflect", cval=0.0, origin=0):
    """Calculate a 1-D correlation along the given axis.

    The lines of the array along the given axis are correlated with the
    given weights.

    Parameters
    ----------
    input : array_like
        The input array.
    weights : array
        1-D sequence of numbers.
    axis : int, optional
        Axis of `input` along which to calculate. Default is -1.
    output : array or dtype, optional
        Output specification handed to ``_ni_support._get_output``.
    mode : str, optional
        Boundary-extension mode, translated with
        ``_ni_support._extend_mode_to_code``. Default is 'reflect'.
    cval : scalar, optional
        Fill value passed to the underlying filter. Default is 0.0.
    origin : int, optional
        Placement of the filter; must satisfy
        ``-(len(weights) // 2) <= origin <= (len(weights)-1) // 2``.

    Returns
    -------
    result : ndarray
        Correlation result. Has the same shape as `input`.

    Raises
    ------
    RuntimeError
        If `weights` is not a non-empty 1-D sequence.
    ValueError
        If `origin` is out of range.

    Examples
    --------
    >>> _correlate1d([2, 8, 0, 4, 1, 9, 9, 0], weights=[1, 3])
    array([ 8, 26,  8, 12,  7, 28, 36,  9])
    """
    input = np.asarray(input)
    weights = np.asarray(weights)
    complex_input = input.dtype.kind == "c"
    complex_weights = weights.dtype.kind == "c"
    if complex_input or complex_weights:
        # Complex data is handled as a combination of real correlations.
        if complex_weights:
            weights = weights.conj()
            weights = weights.astype(np.complex128, copy=False)
        kwargs = dict(axis=axis, mode=mode, origin=origin)
        output = _ni_support._get_output(output, input, complex_output=True)
        return _complex_via_real_components(_correlate1d, input, weights, output, cval, **kwargs)

    output = _ni_support._get_output(output, input)
    weights = np.asarray(weights, dtype=np.float64)
    if weights.ndim != 1 or weights.shape[0] < 1:
        raise RuntimeError("no filter weights given")
    if not weights.flags.contiguous:
        weights = weights.copy()
    axis = _normalize_axis_index(axis, input.ndim)
    if _invalid_origin(origin, len(weights)):
        raise ValueError(
            "Invalid origin; origin must satisfy "
            "-(len(weights) // 2) <= origin <= "
            "(len(weights)-1) // 2"
        )
    mode = _ni_support._extend_mode_to_code(mode)
    _nd_image.correlate1d(input, weights, axis, output, mode, cval, origin)
    return output


def _complex_via_real_components(func, input, weights, output, cval, **kwargs):
    """Complex convolution via a linear combination of real convolutions.

    ``func`` is called on the real/imaginary parts of ``input`` and ``weights``
    and the partial results are accumulated into ``output``, which is returned.
    Raises ValueError if ``cval`` is complex while ``input`` is real.
    """
    complex_input = input.dtype.kind == "c"
    complex_weights = weights.dtype.kind == "c"
    if complex_input and complex_weights:
        # real component of the output
        func(input.real, weights.real, output=output.real, cval=np.real(cval), **kwargs)
        output.real -= func(input.imag, weights.imag, output=None, cval=np.imag(cval), **kwargs)
        # imaginary component of the output
        func(input.real, weights.imag, output=output.imag, cval=np.real(cval), **kwargs)
        output.imag += func(input.imag, weights.real, output=None, cval=np.imag(cval), **kwargs)
    elif complex_input:
        func(input.real, weights, output=output.real, cval=np.real(cval), **kwargs)
        func(input.imag, weights, output=output.imag, cval=np.imag(cval), **kwargs)
    else:
        if np.iscomplexobj(cval):
            raise ValueError("Cannot provide a complex-valued cval when the input is real.")
        func(input, weights.real, output=output.real, cval=cval, **kwargs)
        func(input, weights.imag, output=output.imag, cval=cval, **kwargs)
    return output


def _normalize_axis_index(axis, ndim):
    """Return ``axis`` as a non-negative index, raising AxisError if out of range."""
    if axis < -ndim or axis >= ndim:
        msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
        raise AxisError(msg)

    if axis < 0:
        axis = axis + ndim
    return axis


def _compute_persistence(
    sspikes,
    dim=6,
    num_times=5,
    active_times=15000,
    k=1000,
    n_points=1200,
    metric="cosine",
    nbs=800,
    maxdim=1,
    coeff=47,
    progress_bar=True,
):
    """Run the same five pipeline steps as ``_compute_real_persistence`` from plain arguments.

    Used for shuffled data; returns only the ripser result dictionary.
    """
    # Time point downsampling
    times_cube = np.arange(0, sspikes.shape[0], num_times)

    # Select most active time points
    movetimes = np.sort(np.argsort(np.sum(sspikes[times_cube, :], 1))[-active_times:])
    movetimes = times_cube[movetimes]

    # PCA dimensionality reduction
    scaled_data = preprocessing.scale(sspikes[movetimes, :])
    dimred, *_ = _pca(scaled_data, dim=dim)

    # Point cloud sampling (denoising)
    indstemp, *_ = _sample_denoising(dimred, k, n_points, 1, metric)

    # Build distance matrix
    d = _second_build(dimred, indstemp, metric=metric, nbs=nbs)
    np.fill_diagonal(d, 0)

    # Compute persistent homology
    persistence = ripser(
        d,
        maxdim=maxdim,
        coeff=coeff,
        do_cocycles=True,
        distance_matrix=True,
        progress_bar=progress_bar,
    )

    return persistence


def _pca(data, dim=2):
    """
    Perform PCA (Principal Component Analysis) for dimensionality reduction.

    The data is projected as-is (no mean-centering is applied here).

    Parameters:
        data (ndarray): Input data matrix of shape (N_samples, N_features).
        dim (int): Target dimension for PCA projection.

    Returns:
        components (ndarray): Projected data of shape (N_samples, dim). For
            ``dim < 2`` the input is returned unchanged.
        var_exp (list): Percentage of variance explained by each retained
            component (``[0]`` for ``dim < 2``).
        evals (ndarray): Eigenvalues of the retained components (empty for
            ``dim < 2``).
    """
    if dim < 2:
        # Always return three values so callers can unpack uniformly.
        return data, [0], np.array([])

    # calculate the covariance matrix
    R = np.cov(data, rowvar=False)
    # NOTE(review): R is symmetric, so np.linalg.eigh would be the conventional
    # (and faster) choice; np.linalg.eig is kept because switching can change
    # eigenvector signs/ordering and therefore the projected coordinates.
    evals, evecs = np.linalg.eig(R)
    # sort eigenvalue in decreasing order
    idx = np.argsort(evals)[::-1]
    evecs = evecs[:, idx]
    # sort eigenvectors according to same index
    evals = evals[idx]
    # keep the eigenvectors of the `dim` largest eigenvalues
    evecs = evecs[:, :dim]
    # project the data onto the retained eigenvectors
    tot = np.sum(evals)
    var_exp = [(i / tot) * 100 for i in sorted(evals[:dim], reverse=True)]
    components = np.dot(evecs.T, data.T).T
    return components, var_exp, evals[:dim]


def _sample_denoising(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """
    Perform denoising and greedy sampling based on mutual k-NN graph.

    Dispatches to the numba implementation when ``HAS_NUMBA`` is true and to
    the numpy implementation otherwise.

    Parameters:
        data (ndarray): High-dimensional point cloud data.
        k (int): Number of neighbors for local density estimation.
        num_sample (int): Number of samples to retain.
        omega (float): Suppression factor during greedy sampling.
        metric (str): Distance metric used for kNN ('euclidean', 'cosine', etc).

    Returns:
        inds (ndarray): Indices of sampled points.
        d (ndarray): Pairwise similarity matrix of sampled points.
        Fs (ndarray): Sampling scores at each step.
    """
    if HAS_NUMBA:
        return _sample_denoising_numba(data, k, num_sample, omega, metric)
    else:
        return _sample_denoising_numpy(data, k, num_sample, omega, metric)


def _sample_denoising_numpy(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """Numpy/scipy implementation of ``_sample_denoising`` (fallback path).

    Builds the symmetrised fuzzy k-NN graph, then greedily picks the point
    with the highest remaining score ``F``; after each pick ``F`` is reduced by
    ``omega`` times the picked point's graph weights.
    """
    n = data.shape[0]
    X = squareform(pdist(data, metric))
    knn_indices = np.argsort(X)[:, :k]
    knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()

    sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    # Symmetrise: A + A^T - A * A^T (element-wise product)
    result = coo_matrix((vals, (rows, cols)), shape=(n, n))
    result.eliminate_zeros()
    transpose = result.transpose()
    prod_matrix = result.multiply(transpose)
    result = result + transpose - prod_matrix
    result.eliminate_zeros()
    X = result.toarray()

    F = np.sum(X, 1)
    Fs = np.zeros(num_sample)
    Fs[0] = np.max(F)
    i = np.argmax(F)
    inds_all = np.arange(n)
    inds_left = inds_all > -1
    inds_left[i] = False
    inds = np.zeros(num_sample, dtype=int)
    inds[0] = i

    for j in np.arange(1, num_sample):
        F -= omega * X[i, :]
        Fmax = np.argmax(F[inds_left])
        # Exactly match external TDAvis implementation (including the indexing
        # logic): Fmax indexes the filtered array but is used on the full F.
        Fs[j] = F[Fmax]
        i = inds_all[inds_left][Fmax]

        inds_left[i] = False
        inds[j] = i

    d = np.zeros((num_sample, num_sample))

    for j, i in enumerate(inds):
        d[j, :] = X[i, inds]

    return inds, d, Fs


def _sample_denoising_numba(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """Implementation of ``_sample_denoising`` built on the numba-compiled helpers."""
    n = data.shape[0]
    X = squareform(pdist(data, metric))
    knn_indices = np.argsort(X)[:, :k]
    knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()

    sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    # Build symmetric adjacency matrix using optimized function
    X_adj = _build_adjacency_matrix_numba(rows, cols, vals, n)

    # Greedy sampling using optimized function
    inds, Fs = _greedy_sampling_numba(X_adj, num_sample, omega)

    # Build final distance matrix
    d = _build_distance_matrix_numba(X_adj, inds)

    return inds, d, Fs


@njit(fastmath=True)
def _build_adjacency_matrix_numba(rows, cols, vals, n):
    """Build the dense symmetric adjacency matrix from COO triplets.

    Mirrors the scipy sparse operations of the numpy path:
    ``result + transpose - result.multiply(transpose)``.
    """
    # Initialize matrices
    X = np.zeros((n, n), dtype=np.float64)
    X_T = np.zeros((n, n), dtype=np.float64)

    # Build adjacency matrix and its transpose simultaneously
    for i in range(len(rows)):
        X[rows[i], cols[i]] = vals[i]
        X_T[cols[i], rows[i]] = vals[i]  # Transpose

    # Apply the symmetrization formula: A = A + A^T - A * A^T (element-wise)
    X[:, :] = X + X_T - X * X_T

    return X


@njit(fastmath=True)
def _greedy_sampling_numba(X, num_sample, omega):
    """Greedy score-based sampling on adjacency matrix ``X``.

    Returns:
        inds (ndarray): Indices picked at each step.
        Fs (ndarray): Score of the picked point at each step.
    """
    n = X.shape[0]
    F = np.sum(X, axis=1)
    Fs = np.zeros(num_sample)
    inds = np.zeros(num_sample, dtype=np.int64)
    inds_left = np.ones(n, dtype=np.bool_)

    # Initialize with maximum F
    i = np.argmax(F)
    Fs[0] = F[i]
    inds[0] = i
    inds_left[i] = False

    # Greedy sampling loop
    for j in range(1, num_sample):
        # Suppress the scores of the neighbours of the last picked point
        for k in range(n):
            F[k] -= omega * X[i, k]

        # Find maximum among remaining points
        max_val = -np.inf
        max_idx = -1
        for k in range(n):
            if inds_left[k] and F[k] > max_val:
                max_val = F[k]
                max_idx = k

        # Record the F value using the selected index
        i = max_idx
        Fs[j] = F[i]
        inds[j] = i
        inds_left[i] = False

    return inds, Fs


@njit(fastmath=True)
def _build_distance_matrix_numba(X, inds):
    """Return the sub-matrix of ``X`` restricted to rows and columns ``inds``."""
    num_sample = len(inds)
    d = np.zeros((num_sample, num_sample))

    for j in range(num_sample):
        for k in range(num_sample):
            d[j, k] = X[inds[j], inds[k]]

    return d


@njit(fastmath=True)
def _smooth_knn_dist(distances, k, n_iter=64, local_connectivity=0.0, bandwidth=1.0):
    """
    Compute smoothed local distances for kNN graph with entropy balancing.

    For each point a binary search finds ``sigma`` such that the sum of
    ``exp(-(d - rho) / sigma)`` over its neighbours (excluding column 0)
    approaches ``log2(k) * bandwidth``.

    Parameters:
        distances (ndarray): kNN distance matrix.
        k (int): Number of neighbors.
        n_iter (int): Number of binary search iterations.
        local_connectivity (float): Minimum local connectivity.
        bandwidth (float): Bandwidth parameter.

    Returns:
        sigmas (ndarray): Smoothed sigma values for each point.
        rhos (ndarray): Minimum distances (connectivity cutoff) for each point.
    """
    target = np.log2(k) * bandwidth
    rho = np.zeros(distances.shape[0])
    result = np.zeros(distances.shape[0])

    mean_distances = np.mean(distances)

    for i in range(distances.shape[0]):
        lo = 0.0
        hi = np.inf
        mid = 1.0

        ith_distances = distances[i]
        non_zero_dists = ith_distances[ith_distances > 0.0]
        if non_zero_dists.shape[0] >= local_connectivity:
            index = int(np.floor(local_connectivity))
            interpolation = local_connectivity - index
            if index > 0:
                rho[i] = non_zero_dists[index - 1]
                if interpolation > 1e-5:
                    rho[i] += interpolation * (non_zero_dists[index] - non_zero_dists[index - 1])
            else:
                rho[i] = interpolation * non_zero_dists[0]
        elif non_zero_dists.shape[0] > 0:
            rho[i] = np.max(non_zero_dists)

        # Binary search for sigma
        for _ in range(n_iter):
            d_array = distances[i, 1:] - rho[i]
            # Neighbours at or below rho contribute exactly 1
            psum = np.sum(np.where(d_array > 0, np.exp(-(d_array / mid)), 1.0))

            if np.fabs(psum - target) < 1e-5:
                break

            if psum > target:
                hi = mid
                mid = (lo + hi) / 2.0
            else:
                lo = mid
                if hi == np.inf:
                    mid *= 2
                else:
                    mid = (lo + hi) / 2.0
        result[i] = mid

        # Keep sigma from collapsing below 1e-3 of the (local or global) mean distance
        if rho[i] > 0.0:
            mean_ith_distances = np.mean(ith_distances)
            if result[i] < 1e-3 * mean_ith_distances:
                result[i] = 1e-3 * mean_ith_distances
        else:
            if result[i] < 1e-3 * mean_distances:
                result[i] = 1e-3 * mean_distances

    return result, rho


@njit(parallel=True, fastmath=True)
def _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos):
    """
    Compute membership strength matrix from smoothed kNN graph.

    Parameters:
        knn_indices (ndarray): Indices of k-nearest neighbors.
        knn_dists (ndarray): Corresponding distances.
        sigmas (ndarray): Local bandwidths.
        rhos (ndarray): Minimum distance thresholds.

    Returns:
        rows (ndarray): Row indices for sparse matrix.
        cols (ndarray): Column indices for sparse matrix.
        vals (ndarray): Weight values for sparse matrix.
    """
    n_samples = knn_indices.shape[0]
    n_neighbors = knn_indices.shape[1]
    rows = np.zeros((n_samples * n_neighbors), dtype=np.int64)
    cols = np.zeros((n_samples * n_neighbors), dtype=np.int64)
    vals = np.zeros((n_samples * n_neighbors), dtype=np.float64)
    for i in range(n_samples):
        for j in range(n_neighbors):
            if knn_indices[i, j] == -1:
                continue  # We didn't get the full knn for i
            if knn_indices[i, j] == i:
                # No self-edges
                val = 0.0
            elif knn_dists[i, j] - rhos[i] <= 0.0:
                val = 1.0
            else:
                val = np.exp(-((knn_dists[i, j] - rhos[i]) / (sigmas[i])))

            rows[i * n_neighbors + j] = i
            cols[i * n_neighbors + j] = knn_indices[i, j]
            vals[i * n_neighbors + j] = val

    return rows, cols, vals


def _second_build(data, indstemp, nbs=800, metric="cosine"):
    """
    Reconstruct distance matrix after denoising for persistent homology.

    Parameters:
        data (ndarray): PCA-reduced data matrix.
        indstemp (ndarray): Indices of sampled points.
        nbs (int): Number of neighbors in reconstructed graph.
        metric (str): Distance metric ('cosine', 'euclidean', etc).

    Returns:
        d (ndarray): Symmetric distance matrix used for persistent homology;
            entries are ``-log`` of the graph weights (``inf`` where the weight
            is zero) with a zero diagonal.
    """
    # Filter the data using the sampled point indices
    data = data[indstemp, :]

    # Compute the pairwise distance matrix
    X = squareform(pdist(data, metric))
    knn_indices = np.argsort(X)[:, :nbs]
    knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()

    # Compute smoothed kernel widths
    sigmas, rhos = _smooth_knn_dist(knn_dists, nbs, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    # Construct a sparse graph and symmetrise it
    result = coo_matrix((vals, (rows, cols)), shape=(X.shape[0], X.shape[0]))
    result.eliminate_zeros()
    transpose = result.transpose()
    prod_matrix = result.multiply(transpose)
    result = result + transpose - prod_matrix
    result.eliminate_zeros()

    # Build the final distance matrix
    d = result.toarray()
    # Match external TDAvis: direct negative log without epsilon handling
    # Temporarily suppress divide by zero warning to match external behavior
    with np.errstate(divide="ignore", invalid="ignore"):
        d = -np.log(d)
    np.fill_diagonal(d, 0)

    return d


def _run_shuffle_analysis(sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs):
    """Run the shuffle analysis; thin wrapper around the multiprocessing implementation."""
    return _run_shuffle_analysis_multiprocessing(
        sspikes, num_shuffles, num_cores, progress_bar, **kwargs
    )


def _run_shuffle_analysis_multiprocessing(
    sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs
):
    """Compute the per-dimension null distribution of maximum lifetimes in a process pool.

    Parameters:
        sspikes (ndarray): Embedded spike data to shuffle.
        num_shuffles (int): Number of shuffles.
        num_cores (int): Number of worker processes.
        progress_bar (bool): Forwarded to every per-shuffle persistence
            computation unless ``kwargs`` already contains ``progress_bar``.
        **kwargs: Parameters forwarded to ``_compute_persistence``.

    Returns:
        dict mapping 0, 1, 2 to the list of maximum finite lifetimes of the
        shuffles that succeeded for that dimension.
    """
    # NaN marks shuffles that failed or produced no finite bar in a dimension.
    max_lifetimes = {
        0: np.full(num_shuffles, np.nan),
        1: np.full(num_shuffles, np.nan),
        2: np.full(num_shuffles, np.nan),
    }

    # Honour the caller's progress-bar preference inside every shuffle.
    task_kwargs = dict(kwargs)
    task_kwargs.setdefault("progress_bar", progress_bar)

    # One un-pooled iteration is run first and its result discarded.
    # NOTE(review): the log message mentions runtime estimation but nothing is
    # timed; presumably this also surfaces errors before the pool is started.
    logging.info("Running test iteration to estimate runtime...")

    _ = _process_single_shuffle((0, sspikes, task_kwargs))

    # Forked workers inherit the parent's global numpy random state, so without
    # per-task seeds different workers can generate identical shuffles. Seeds
    # are drawn from the global generator so np.random.seed() still controls
    # reproducibility.
    seeds = np.random.randint(0, np.iinfo(np.int32).max, size=num_shuffles)
    tasks = [(i, sspikes, task_kwargs, int(seeds[i])) for i in range(num_shuffles)]
    logging.info(
        f"Starting shuffle analysis with {num_shuffles} iterations using {num_cores} cores..."
    )

    # Use multiprocessing pool for parallel processing
    with mp.Pool(processes=num_cores) as pool:
        results = list(pool.imap(_process_single_shuffle, tasks))
    logging.info("Shuffle analysis completed")

    # Collect results by shuffle index
    for idx, res in enumerate(results):
        for dim, lifetime in res.items():
            max_lifetimes[dim][idx] = lifetime

    # Drop NaN entries (failed results) and convert to plain lists
    for dim in max_lifetimes:
        max_lifetimes[dim] = max_lifetimes[dim][~np.isnan(max_lifetimes[dim])].tolist()

    return max_lifetimes


@njit(fastmath=True)
def _fast_pca_transform(data, components):
    """Return ``data @ components.T`` (numba-compiled)."""
    return np.dot(data, components.T)


def _process_single_shuffle(args):
    """Shuffle the data once and return the maximum finite lifetime per dimension.

    Parameters:
        args (tuple): ``(index, sspikes, kwargs)`` or
            ``(index, sspikes, kwargs, seed)``. When a seed is present the
            shuffle uses its own ``numpy.random.Generator`` seeded with it;
            otherwise the global ``np.random`` state is used.

    Returns:
        dict mapping dimension (0, 1, 2) to the largest finite lifetime, with
        dimensions lacking finite bars omitted; ``{}`` if the shuffle failed
        (the failure is printed, not raised, so one bad shuffle does not abort
        the whole analysis).
    """
    i, sspikes, kwargs, *extra = args
    seed = extra[0] if extra else None
    try:
        rng = np.random.default_rng(seed) if seed is not None else None
        shuffled_data = _shuffle_spike_trains(sspikes, rng=rng)
        persistence = _compute_persistence(shuffled_data, **kwargs)

        dim_max_lifetimes = {}
        for dim in [0, 1, 2]:
            if dim < len(persistence["dgms"]):
                # Infinite bars carry no usable lifetime
                lifetimes = [
                    bar[1] - bar[0] for bar in persistence["dgms"][dim] if not np.isinf(bar[1])
                ]
                if lifetimes:
                    dim_max_lifetimes[dim] = max(lifetimes)
        return dim_max_lifetimes
    except Exception as e:
        print(f"Shuffle {i} failed: {str(e)}")
        return {}


def _shuffle_spike_trains(sspikes, rng=None):
    """Circularly shift every neuron's column by an independent random offset.

    Parameters:
        sspikes (ndarray): Array of shape (time bins, neurons); not modified.
        rng (numpy.random.Generator, optional): Source of randomness. When
            None, the global ``np.random`` state is used.

    Returns:
        ndarray: Shifted copy. Offsets are drawn from ``[0, 10% of the bins)``.

    Raises:
        ValueError: If there are fewer than 10 time bins, so no offset range exists.
    """
    shuffled = sspikes.copy()
    num_bins = shuffled.shape[0]
    num_neurons = shuffled.shape[1]

    max_shift = int(num_bins * 0.1)
    if max_shift < 1:
        raise ValueError(
            f"Cannot shuffle spike trains with only {num_bins} time bins; at least 10 are required."
        )

    draw_shift = np.random.randint if rng is None else rng.integers

    # Independent shift for each neuron
    for n in range(num_neurons):
        shift = draw_shift(0, max_shift)
        shuffled[:, n] = np.roll(shuffled[:, n], shift)

    return shuffled


def _plot_barcode(persistence):
    """
    Plot barcode diagram from persistent homology result.

    One panel per homology dimension, showing at most the 30 longest-lived
    bars; infinite deaths are drawn out to the right axis limit.

    Parameters:
        persistence (dict): Persistent homology result with 'dgms' key.

    Returns:
        fig: The created matplotlib figure.
    """
    cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T  # RGB color for each dimension
    alpha = 1
    inf_delta = 0.1
    colormap = cs
    dgms = persistence["dgms"]
    maxdim = len(dgms) - 1
    dims = np.arange(maxdim + 1)

    # Determine axis range from the finite bars
    min_birth, max_death = 0, 0
    for dim in dims:
        persistence_dim = dgms[dim][~np.isinf(dgms[dim][:, 1]), :]
        if persistence_dim.size > 0:
            min_birth = min(min_birth, np.min(persistence_dim))
            max_death = max(max_death, np.max(persistence_dim))

    # Without any finite bar the range would collapse to zero width
    if max_death == 0 and min_birth == 0:
        max_death = 1

    delta = (max_death - min_birth) * inf_delta
    infinity = max_death + delta
    axis_start = min_birth - delta

    # Create plot
    fig = plt.figure(figsize=(10, 6))
    gs = gridspec.GridSpec(len(dims), 1)

    for dim in dims:
        axes = plt.subplot(gs[dim])
        axes.axis("on")
        axes.set_yticks([])
        # Label generated per dimension so maxdim > 2 does not run out of labels
        axes.set_ylabel(f"$H_{dim}$", rotation=0, labelpad=20, fontsize=12)

        d = np.copy(dgms[dim])
        d[np.isinf(d[:, 1]), 1] = infinity
        dlife = d[:, 1] - d[:, 0]

        # Select top 30 bars by lifetime
        dinds = np.argsort(dlife)[-30:]
        if dim > 0:
            dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]

        axes.barh(
            0.5 + np.arange(len(dinds)),
            dlife[dinds],
            height=0.8,
            left=d[dinds, 0],
            alpha=alpha,
            color=colormap[dim % len(colormap)],
            linewidth=0,
        )

        axes.plot([0, 0], [0, len(dinds)], c="k", linestyle="-", lw=1)
        axes.plot([0, len(dinds)], [0, 0], c="k", linestyle="-", lw=1)
        axes.set_xlim([axis_start, infinity])

    plt.tight_layout()
    return fig


def _plot_barcode_with_shuffle(persistence, shuffle_max):
    """
    Plot barcode with shuffle region markers.

    Each panel shades the range from 0 to the 99.9th percentile of the shuffle
    lifetimes for that dimension and colours bars exceeding it red.

    Parameters:
        persistence (dict): Persistent homology result with 'dgms' key.
        shuffle_max (dict or None): Per-dimension shuffle lifetimes; None is
            treated as no shuffle data (threshold 0).

    Returns:
        fig: The created matplotlib figure.
    """
    # Handle case where shuffle_max is None
    if shuffle_max is None:
        shuffle_max = {}

    cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T
    alpha = 1
    inf_delta = 0.1
    colormap = cs
    maxdim = len(persistence["dgms"]) - 1
    dims = np.arange(maxdim + 1)

    min_birth, max_death = 0, 0
    for dim in dims:
        # Filter out infinite values
        valid_bars = [bar for bar in persistence["dgms"][dim] if not np.isinf(bar[1])]
        if valid_bars:
            min_birth = min(min_birth, np.min(valid_bars))
            max_death = max(max_death, np.max(valid_bars))

    # Handle case with no valid bars
    if max_death == 0 and min_birth == 0:
        min_birth = 0
        max_death = 1

    delta = (max_death - min_birth) * inf_delta
    infinity = max_death + delta

    # Create figure
    fig = plt.figure(figsize=(10, 8))
    gs = gridspec.GridSpec(len(dims), 1)

    # Get shuffle thresholds (99.9th percentile for each dimension)
    thresholds = {}
    for dim in dims:
        if dim in shuffle_max and shuffle_max[dim]:
            thresholds[dim] = np.percentile(shuffle_max[dim], 99.9)
        else:
            thresholds[dim] = 0

    for dim in dims:
        axes = plt.subplot(gs[dim])
        axes.axis("off")
        threshold = thresholds[dim]
        # Same colour rows repeat, so wrap instead of failing for maxdim > 2
        base_color = colormap[dim % len(colormap)]

        # Add gray background to represent shuffle region
        axes.axvspan(0, threshold, alpha=0.2, color="gray", zorder=-3)
        axes.axvline(x=threshold, color="gray", linestyle="--", alpha=0.7)

        # Do not pre-filter out infinite bars; copy the full diagram instead
        d = np.copy(persistence["dgms"][dim])
        if d.size == 0:
            d = np.zeros((0, 2))

        # Map infinite death values to a finite upper bound for visualization
        d[np.isinf(d[:, 1]), 1] = infinity
        dlife = d[:, 1] - d[:, 0]

        # Select top 30 longest-lived bars
        if len(dlife) > 0:
            dinds = np.argsort(dlife)[-30:]
            if dim > 0:
                dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]

            # Draw bars; those longer than the shuffle threshold are significant
            for i, idx in enumerate(dinds):
                color = "red" if dlife[idx] > threshold else base_color
                axes.barh(
                    0.5 + i,
                    dlife[idx],
                    height=0.8,
                    left=d[idx, 0],
                    alpha=alpha,
                    color=color,
                    linewidth=0,
                )

            indsall = len(dinds)
        else:
            indsall = 0

        axes.plot([0, 0], [0, indsall], c="k", linestyle="-", lw=1)
        axes.plot([0, indsall], [0, 0], c="k", linestyle="-", lw=1)
        axes.set_xlim([0, infinity])
        axes.set_title(f"$H_{dim}$", loc="left")

    plt.tight_layout()
    return fig
def _decode_coords_from_centroids(dspk, centcosall, centsinall):
    """Decode circular coordinates for each row of ``dspk`` from per-neuron centroids.

    For every neuron ``n`` the standardised activity ``dspk[:, n]`` weights the
    neuron's summed cosine/sine centroids; the weighted sums over neurons are
    converted to an angle in ``[0, 2*pi)`` with ``arctan2``.
    """
    num_bins = dspk.shape[0]
    num_neurons, num_circ = centcosall.shape[0], centcosall.shape[1]
    cos_part = np.zeros((num_bins, num_circ, num_neurons))
    sin_part = np.zeros((num_bins, num_circ, num_neurons))
    for n in range(num_neurons):
        cos_part[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centcosall[n, :, :], 1))
        sin_part[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centsinall[n, :, :], 1))
    return np.arctan2(np.sum(sin_part, 2), np.sum(cos_part, 2)) % (2 * np.pi)


def decode_circular_coordinates(
    persistence_result: dict[str, Any],
    spike_data: dict[str, Any],
    real_ground: bool = True,
    real_of: bool = True,
    save_path: str | None = None,
) -> dict[str, Any]:
    """
    Decode circular coordinates (bump positions) from cohomology.

    Parameters:
        persistence_result : dict containing persistence analysis results with keys:
            - 'persistence': persistent homology result
            - 'indstemp': indices of sampled points
            - 'movetimes': selected time points
            - 'n_points': number of sampled points
        spike_data : dict
            Spike data passed unchanged to ``embed_spike_trains``.
        real_ground : bool
            Whether x, y, t ground truth exists. When true the embeddings are
            computed with the speed filter enabled.
        real_of : bool
            Whether experiment was performed in open field. When false,
            ``coordsbox`` is decoded from the speed-filtered embedding.
        save_path : str, optional
            Path to save decoding results. If None, saves to
            'Results/spikes_decoding.npz'. Missing parent directories are created.

    Returns:
        dict : Dictionary containing decoding results with keys:
            - 'coords': decoded coordinates for all timepoints
            - 'coordsbox': decoded coordinates for box timepoints
            - 'times': time indices for coords
            - 'times_box': time indices for coordsbox
            - 'centcosall': cosine centroids
            - 'centsinall': sine centroids

    Raises:
        IndexError: If the H1 diagram contains fewer than two classes.
    """
    ph_classes = [0, 1]  # Decode the ith most persistent cohomology class
    num_circ = len(ph_classes)
    dec_tresh = 0.99
    coeff = 47

    # Extract persistence analysis results
    persistence = persistence_result["persistence"]
    indstemp = persistence_result["indstemp"]
    movetimes = persistence_result["movetimes"]
    n_points = persistence_result["n_points"]

    diagrams = persistence["dgms"]  # the multiset describing the lives of the persistence classes
    cocycles = persistence["cocycles"][1]  # the cocycle representatives for the 1-dim classes
    dists_land = persistence["dperm2all"]  # the pairwise distance between the points
    births1 = diagrams[1][:, 0]  # the time of birth for the 1-dim classes
    # Copy before editing: a plain slice is a view, and zeroing infinite deaths
    # through it would silently modify the caller's persistence diagram.
    deaths1 = diagrams[1][:, 1].copy()  # the time of death for the 1-dim classes
    deaths1[np.isinf(deaths1)] = 0
    lives1 = deaths1 - births1  # the lifetime for the 1-dim classes
    iMax = np.argsort(lives1)
    coords1 = np.zeros((num_circ, len(indstemp)))
    # Threshold placed at 99% of the lifetime of the second most persistent class
    threshold = births1[iMax[-2]] + (deaths1[iMax[-2]] - births1[iMax[-2]]) * dec_tresh

    for c in ph_classes:
        cocycle = cocycles[iMax[-(c + 1)]]
        coords1[c, :], _ = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)

    def _embed(smooth, speed_filter):
        """Return only the spike matrix from ``embed_spike_trains``."""
        embedded = embed_spike_trains(
            spike_data, config=SpikeEmbeddingConfig(smooth=smooth, speed_filter=speed_filter)
        )
        # With the speed filter on, the result is (spikes, xx, yy, tt).
        return embedded[0] if speed_filter else embedded

    # The speed filter is only applied when real x/y/t ground truth is available.
    # The smoothed embedding is computed once and reused below.
    sspikes_all = _embed(smooth=True, speed_filter=real_ground)

    num_neurons = sspikes_all.shape[1]
    centcosall = np.zeros((num_neurons, 2, n_points))
    centsinall = np.zeros((num_neurons, 2, n_points))
    dspk = preprocessing.scale(sspikes_all[movetimes[indstemp], :])

    for neurid in range(num_neurons):
        spktemp = dspk[:, neurid].copy()
        centcosall[neurid, :, :] = np.multiply(np.cos(coords1[:, :] * 2 * np.pi), spktemp)
        centsinall[neurid, :, :] = np.multiply(np.sin(coords1[:, :] * 2 * np.pi), spktemp)

    spikes = _embed(smooth=False, speed_filter=real_ground)

    # Keep only time bins in which at least one neuron has positive activity
    times = np.where(np.sum(spikes > 0, 1) >= 1)[0]
    dspk = preprocessing.scale(sspikes_all)[times, :]
    coords = _decode_coords_from_centroids(dspk, centcosall, centsinall)

    if real_of or real_ground:
        # Open-field data reuses the decoded coordinates directly. With ground
        # truth available the box decoding would use exactly the same
        # speed-filtered embeddings as above, so the result is identical.
        coordsbox = coords.copy()
        times_box = times.copy()
    else:
        sspikes_box = _embed(smooth=True, speed_filter=True)
        spikes_box = _embed(smooth=False, speed_filter=True)
        times_box = np.where(np.sum(spikes_box > 0, 1) >= 1)[0]
        dspk_box = preprocessing.scale(sspikes_box)[times_box, :]
        coordsbox = _decode_coords_from_centroids(dspk_box, centcosall, centsinall)

    # Prepare results dictionary
    results = {
        "coords": coords,
        "coordsbox": coordsbox,
        "times": times,
        "times_box": times_box,
        "centcosall": centcosall,
        "centsinall": centsinall,
    }

    # Save results
    if save_path is None:
        save_path = "Results/spikes_decoding.npz"
    # A bare file name has an empty directory part, which os.makedirs rejects.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    np.savez_compressed(save_path, **results)

    return results
def plot_cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    dpi: int = 300,
    subsample: int = 10,
) -> plt.Figure:
    """
    Visualize CohoMap 1.0: decoded circular coordinates mapped onto spatial trajectory.

    Creates a two-panel visualization showing how the two decoded circular
    coordinates vary across the animal's spatial trajectory. Each panel displays
    the spatial path colored by the cosine of one circular coordinate dimension.

    Parameters:
        decoding_result : dict
            Dictionary from decode_circular_coordinates() containing:
            - 'coordsbox': decoded coordinates for box timepoints (n_times x n_dims)
            - 'times_box': time indices for coordsbox
        position_data : dict
            Position data containing 'x' and 'y' arrays for spatial coordinates
        save_path : str, optional
            Path to save the visualization. If None, no save performed.
            A bare filename (no directory part) is accepted.
        show : bool, default=False
            Whether to display the visualization. When False the figure is
            closed before it is returned.
        figsize : tuple[int, int], default=(10, 4)
            Figure size (width, height) in inches
        dpi : int, default=300
            Resolution for saved figure
        subsample : int, default=10
            Subsampling interval for plotting (plot every Nth timepoint);
            must be at least 1.

    Returns:
        plt.Figure : The matplotlib figure object

    Raises:
        KeyError : If required keys are missing from input dictionaries
        ValueError : If data dimensions are inconsistent or ``subsample`` < 1
        IndexError : If time indices are out of bounds

    Notes:
        A failure while saving is reported on stdout and does not raise.

    Examples:
        >>> # Decode coordinates
        >>> decoding = decode_circular_coordinates(persistence_result, spike_data)
        >>> # Visualize with trajectory data
        >>> fig = plot_cohomap(
        ...     decoding,
        ...     position_data={'x': xx, 'y': yy},
        ...     save_path='cohomap.png',
        ...     show=True
        ... )
    """
    fig = None
    try:
        # Extract data
        coordsbox = decoding_result["coordsbox"]
        times_box = decoding_result["times_box"]
        xx = position_data["x"]
        yy = position_data["y"]

        if subsample < 1:
            raise ValueError(f"subsample must be >= 1, got {subsample}")

        # Subsample time indices for plotting
        plot_times = np.arange(0, len(coordsbox), subsample)
        x_plot = xx[times_box][plot_times]
        y_plot = yy[times_box][plot_times]

        # Create a two-panel figure (one per cohomology dimension). The
        # colormap is passed explicitly to each scatter, so the global
        # default colormap is left untouched.
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        for dim, axis in enumerate(axes):
            axis.axis("off")
            axis.set_aspect("equal", "box")
            points = axis.scatter(
                x_plot,
                y_plot,
                c=np.cos(coordsbox[plot_times, dim]),
                s=8,
                cmap="viridis",
            )
            fig.colorbar(points, ax=axis, label="cos(coord)")
            axis.set_title(f"CohoMap Dim {dim + 1}", fontsize=10)

        fig.tight_layout()

        # Save if path provided (best effort: errors are reported, not raised)
        if save_path:
            try:
                out_dir = os.path.dirname(save_path)
                if out_dir:
                    # os.makedirs("") raises for bare filenames.
                    os.makedirs(out_dir, exist_ok=True)
                fig.savefig(save_path, dpi=dpi)
                print(f"CohoMap visualization saved to {save_path}")
            except Exception as e:
                print(f"Error saving CohoMap visualization: {e}")

        # Show if requested
        if show:
            plt.show()
        else:
            plt.close(fig)

        return fig

    except (KeyError, ValueError, IndexError) as e:
        print(f"CohoMap visualization failed: {e}")
        if fig is not None:
            plt.close(fig)  # do not leak a half-built figure
        raise
    except Exception as e:
        print(f"Unexpected error in CohoMap visualization: {e}")
        if fig is not None:
            plt.close(fig)
        raise
def plot_3d_bump_on_torus(
    decoding_result: dict[str, Any] | str,
    spike_data: dict[str, Any],
    config: CANN2DPlotConfig | None = None,
    save_path: str | None = None,
    numangsint: int = 51,
    r1: float = 1.5,
    r2: float = 1.0,
    window_size: int = 300,
    frame_step: int = 5,
    n_frames: int = 20,
    fps: int = 5,
    show_progress: bool = True,
    show: bool = True,
    figsize: tuple[int, int] = (8, 8),
    **kwargs,
) -> animation.FuncAnimation | None:
    """
    Visualize the movement of the neural activity bump on a torus using matplotlib animation.

    For each frame, the decoded coordinates falling inside a sliding time
    window are binned on a 2D angular grid (weighted by summed spike counts),
    smoothed, blended with the previous frame, and painted onto a torus surface.

    Parameters:
        decoding_result : dict or str
            Dictionary containing decoding results with 'coordsbox' and 'times_box' keys,
            or path to .npz file containing these results
        spike_data : dict
            Spike data dictionary passed to ``embed_spike_trains``
        config : CANN2DPlotConfig, optional
            Configuration object for unified plotting parameters. When None,
            one is built with ``CANN2DPlotConfig.for_torus_animation(**kwargs)``.
        **kwargs : backward compatibility parameters
            Any key matching an attribute of the config is written onto the
            config object (this also applies to a config passed by the caller).
        save_path : str, optional
            Path to save the animation (e.g., 'animation.gif' or 'animation.mp4').
            Used only when the config carries no ``save_path``.
        numangsint : int
            Grid resolution for the torus surface
        r1 : float
            Major radius of the torus
        r2 : float
            Minor radius of the torus
        window_size : int
            Time window (in number of time points) for each frame
        frame_step : int
            Step size to slide the time window between frames
        n_frames : int
            Total number of frames in the animation
        fps : int
            Frames per second for the output animation
        show_progress : bool
            Whether to show progress bars (frame processing and saving)
        show : bool
            Whether to display the animation
        figsize : tuple[int, int]
            Figure size for the animation

    Returns:
        matplotlib.animation.FuncAnimation or None : The animation object;
        None when shown inside a Jupyter environment (avoids double display).

    Raises:
        ProcessingError : If no frame could be generated or the animation
            could not be created. Errors while saving are only printed.

    Notes:
        NOTE(review): the values of ``numangsint``, ``r1``, ``r2``,
        ``window_size``, ``frame_step``, ``n_frames``, ``fps``,
        ``show_progress``, ``show`` and ``figsize`` given as named arguments are
        overwritten by the corresponding config attributes below, even when
        ``config`` is None. Confirm whether these arguments were meant to be
        forwarded into the config.
    """
    # Handle backward compatibility and configuration
    if config is None:
        config = CANN2DPlotConfig.for_torus_animation(**kwargs)

    # Override config with any explicitly passed parameters
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)

    # Extract configuration values
    save_path = config.save_path if config.save_path else save_path
    show = config.show
    figsize = config.figsize
    fps = config.fps
    show_progress = config.show_progress_bar
    numangsint = config.numangsint
    r1 = config.r1
    r2 = config.r2
    window_size = config.window_size
    frame_step = config.frame_step
    n_frames = config.n_frames

    # Load decoding results if path is provided
    if isinstance(decoding_result, str):
        # NOTE(security): allow_pickle=True can execute arbitrary code when the
        # file comes from an untrusted source; only load trusted result files.
        # The context manager closes the archive even if a key is missing.
        with np.load(decoding_result, allow_pickle=True) as archive:
            coords = archive["coordsbox"]
            times = archive["times_box"]
    else:
        coords = decoding_result["coordsbox"]
        times = decoding_result["times_box"]

    spk, *_ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
    )

    # Prepare animation data
    frame_data = []
    prev_m = None

    for frame_idx in tqdm(
        range(n_frames), desc="Processing frames", disable=not show_progress
    ):
        start_idx = frame_idx * frame_step
        end_idx = start_idx + window_size
        if end_idx > np.max(times):
            break

        mask = (times >= start_idx) & (times < end_idx)
        coords_window = coords[mask]
        if len(coords_window) == 0:
            continue

        spk_window = spk[times[mask], :]
        activity = np.sum(spk_window, axis=1)

        m, x_edge, y_edge, _ = binned_statistic_2d(
            coords_window[:, 0],
            coords_window[:, 1],
            activity,
            statistic="sum",
            bins=np.linspace(0, 2 * np.pi, numangsint - 1),
        )
        m = np.nan_to_num(m)
        m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
        m = gaussian_filter(m, sigma=1.0)

        # Temporal smoothing: blend with the previous frame's map.
        if prev_m is not None:
            m = 0.7 * prev_m + 0.3 * m
        prev_m = m

        # Store processed data for animation
        X, Y = np.meshgrid(x_edge, y_edge)
        X = (X + np.pi / 5) % (2 * np.pi)
        x = (r1 + r2 * np.cos(X)) * np.cos(Y)
        y = (r1 + r2 * np.cos(X)) * np.sin(Y)
        z = -r2 * np.sin(X)  # Flip torus surface orientation

        # "time" is the window start; start_idx already includes frame_step.
        frame_data.append({"x": x, "y": y, "z": z, "m": m, "time": start_idx})

    if not frame_data:
        raise ProcessingError("No valid frames generated for animation")

    # Create figure and animation following plotting package pattern
    fig = plt.figure(figsize=figsize)

    try:
        ax = fig.add_subplot(111, projection="3d")

        def _setup_axes():
            """Apply the fixed limits and camera used for every frame."""
            ax.set_zlim(-2, 2)
            ax.view_init(-125, 135)
            ax.axis("off")

        def _draw_surface(frame):
            """Paint one frame's activity map onto the torus surface."""
            return ax.plot_surface(
                frame["x"],
                frame["y"],
                frame["z"],
                facecolors=cm.viridis(frame["m"] / (np.max(frame["m"]) + 1e-9)),
                alpha=1,
                linewidth=0.1,
                antialiased=True,
                rstride=1,
                cstride=1,
                shade=False,
            )

        # Initialize with first frame
        _setup_axes()
        surface = _draw_surface(frame_data[0])

        def animate(frame_idx):
            """Animation update function following plotting package pattern."""
            if frame_idx >= len(frame_data):
                return (surface,)

            # Clear and redraw surface
            ax.clear()
            _setup_axes()
            new_surface = _draw_surface(frame_data[frame_idx])

            # Update frame counter text
            time_text = ax.text2D(
                0.05,
                0.95,
                f"Frame: {frame_idx + 1}/{len(frame_data)}",
                transform=ax.transAxes,
                fontsize=12,
                bbox=dict(facecolor="white", alpha=0.7),
            )

            return new_surface, time_text

        # Create animation
        interval_ms = 1000 / fps
        ani = animation.FuncAnimation(
            fig, animate, frames=len(frame_data), interval=interval_ms, blit=False, repeat=True
        )

        # Save animation if path provided (best effort: errors are printed)
        if save_path:
            pbar = (
                tqdm(total=len(frame_data), desc=f"Saving animation to {save_path}")
                if show_progress
                else None
            )
            # A leading newline keeps the message off the progress-bar line.
            prefix = "\n" if show_progress else ""
            try:
                writer = animation.PillowWriter(fps=fps)
                if pbar is None:
                    ani.save(save_path, writer=writer)
                else:
                    ani.save(
                        save_path,
                        writer=writer,
                        progress_callback=lambda current_frame, total_frames: pbar.update(1),
                    )
            except Exception as e:
                outcome = f"{prefix}Error saving animation: {e}"
            else:
                outcome = f"{prefix}Animation saved to: {save_path}"
            finally:
                if pbar is not None:
                    pbar.close()
            print(outcome)

        if show:
            # Automatically detect Jupyter and display as HTML/JS
            if is_jupyter_environment():
                display_animation_in_jupyter(ani)
                plt.close(fig)  # Close after HTML conversion to prevent auto-display
            else:
                plt.show()
        else:
            plt.close(fig)  # Close if not showing

        # Return None in Jupyter when showing to avoid double display
        if show and is_jupyter_environment():
            return None
        return ani

    except Exception as e:
        plt.close(fig)
        raise ProcessingError(f"Failed to create torus animation: {e}") from e
def _get_coords(cocycle, threshold, num_sampled, dists, coeff): """ Reconstruct circular coordinates from cocycle information. Parameters: cocycle (ndarray): Persistent cocycle representative. threshold (float): Maximum allowable edge distance. num_sampled (int): Number of sampled points. dists (ndarray): Pairwise distance matrix. coeff (int): Finite field modulus for cohomology. Returns: f (ndarray): Circular coordinate values (in [0,1]). verts (ndarray): Indices of used vertices. """ zint = np.where(coeff - cocycle[:, 2] < cocycle[:, 2]) cocycle[zint, 2] = cocycle[zint, 2] - coeff d = np.zeros((num_sampled, num_sampled)) d[np.tril_indices(num_sampled)] = np.nan d[cocycle[:, 1], cocycle[:, 0]] = cocycle[:, 2] d[dists > threshold] = np.nan d[dists == 0] = np.nan edges = np.where(~np.isnan(d)) verts = np.array(np.unique(edges)) num_edges = np.shape(edges)[1] num_verts = np.size(verts) values = d[edges] A = np.zeros((num_edges, num_verts), dtype=int) v1 = np.zeros((num_edges, 2), dtype=int) v2 = np.zeros((num_edges, 2), dtype=int) for i in range(num_edges): # Extract scalar indices from np.where results idx1 = np.where(verts == edges[0][i])[0] idx2 = np.where(verts == edges[1][i])[0] # Handle case where np.where returns multiple matches (shouldn't happen in valid data) if len(idx1) > 0: v1[i, :] = [i, idx1[0]] else: raise ValueError(f"No vertex found for edge {edges[0][i]}") if len(idx2) > 0: v2[i, :] = [i, idx2[0]] else: raise ValueError(f"No vertex found for edge {edges[1][i]}") A[v1[:, 0], v1[:, 1]] = -1 A[v2[:, 0], v2[:, 1]] = 1 L = np.ones((num_edges,)) Aw = A * np.sqrt(L[:, np.newaxis]) Bw = values * np.sqrt(L) f = lsmr(Aw, Bw)[0] % 1 return f, verts def _smooth_tuning_map(mtot, numangsint, sig, bClose=True): """ Smooth activity map over circular topology (e.g., torus). Parameters: mtot (ndarray): Raw activity map matrix. numangsint (int): Grid resolution. sig (float): Smoothing kernel standard deviation. 
bClose (bool): Whether to assume circular boundary conditions. Returns: mtot_out (ndarray): Smoothed map matrix. """ numangsint_1 = numangsint - 1 mid = int((numangsint_1) / 2) indstemp1 = np.zeros((numangsint_1, numangsint_1), dtype=int) indstemp1[indstemp1 == 0] = np.arange((numangsint_1) ** 2) mid = int((numangsint_1) / 2) mtemp1_3 = mtot.copy() for i in range(numangsint_1): mtemp1_3[i, :] = np.roll(mtemp1_3[i, :], int(i / 2)) mtot_out = np.zeros_like(mtot) mtemp1_4 = np.concatenate((mtemp1_3, mtemp1_3, mtemp1_3), 1) mtemp1_5 = np.zeros_like(mtemp1_4) mtemp1_5[:, :mid] = mtemp1_4[:, (numangsint_1) * 3 - mid :] mtemp1_5[:, mid:] = mtemp1_4[:, : (numangsint_1) * 3 - mid] if bClose: mtemp1_6 = _smooth_image(np.concatenate((mtemp1_5, mtemp1_4, mtemp1_5)), sigma=sig) else: mtemp1_6 = gaussian_filter(np.concatenate((mtemp1_5, mtemp1_4, mtemp1_5)), sigma=sig) for i in range(numangsint_1): mtot_out[i, :] = mtemp1_6[ (numangsint_1) + i, (numangsint_1) + (int(i / 2) + 1) : (numangsint_1) * 2 + (int(i / 2) + 1), ] return mtot_out def _smooth_image(img, sigma): """ Smooth image using multivariate Gaussian kernel, handling missing (NaN) values. Parameters: img (ndarray): Input image matrix. sigma (float): Standard deviation of smoothing kernel. Returns: imgC (ndarray): Smoothed image with inpainting around NaNs. 
""" filterSize = max(np.shape(img)) grid = np.arange(-filterSize + 1, filterSize, 1) xx, yy = np.meshgrid(grid, grid) pos = np.dstack((xx, yy)) var = multivariate_normal(mean=[0, 0], cov=[[sigma**2, 0], [0, sigma**2]]) k = var.pdf(pos) k = k / np.sum(k) nans = np.isnan(img) imgA = img.copy() imgA[nans] = 0 imgA = signal.convolve2d(imgA, k, mode="valid") imgD = img.copy() imgD[nans] = 0 imgD[~nans] = 1 radius = 1 L = np.arange(-radius, radius + 1) X, Y = np.meshgrid(L, L) dk = np.array((X**2 + Y**2) <= radius**2, dtype=bool) imgE = np.zeros((filterSize + 2, filterSize + 2)) imgE[1:-1, 1:-1] = imgD imgE = binary_closing(imgE, iterations=1, structure=dk) imgD = imgE[1:-1, 1:-1] imgB = np.divide( signal.convolve2d(imgD, k, mode="valid"), signal.convolve2d(np.ones(np.shape(imgD)), k, mode="valid"), ) imgC = np.divide(imgA, imgB) imgC[imgD == 0] = -np.inf return imgC if __name__ == "__main__": from canns.data.loaders import load_grid_data
[docs] data = load_grid_data()
spikes, xx, yy, tt = embed_spike_trains(data) # import umap # # reducer = umap.UMAP( # n_neighbors=15, # min_dist=0.1, # n_components=3, # metric='euclidean', # random_state=42 # ) # # reduce_func = reducer.fit_transform # # plot_projection(reduce_func=reduce_func, embed_data=spikes, show=True) results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=False, show=True) decoding = decode_circular_coordinates( persistence_result=results, spike_data=data, real_ground=True, real_of=True, ) # Visualize CohoMap plot_cohomap( decoding_result=decoding, position_data={"x": xx, "y": yy}, save_path="Results/cohomap.png", show=True, ) # results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=True, num_shuffles=10, show=True)