Source code for src.canns.data.loaders

"""
Experimental data processing utilities for CANNs.

This module provides specialized functions for processing experimental data
typically used in CANN analyses, including ROI data, grid cell data, and
other neurophysiological datasets: loaders (``load_roi_data``,
``load_grid_data``), validators (``validate_roi_data``,
``validate_grid_data``), spike preprocessing (``preprocess_spike_data``) and
summary statistics (``get_data_summary``).
"""

import io
from pathlib import Path
from typing import Any

import numpy as np

from . import datasets as _datasets


[docs] def load_roi_data(source: str | Path | None = None) -> np.ndarray | None: """ Load ROI data for 1D CANN analysis. Parameters ---------- source : str, Path, or None Data source. Can be: - URL string: downloads and loads from URL - Path: loads from local file - None: uses default CANNs dataset Returns ------- ndarray or None ROI data array if successful, None otherwise. Examples -------- >>> # Load default dataset >>> roi_data = load_roi_data() >>> >>> # Load from URL >>> roi_data = load_roi_data('https://example.com/roi_data.txt') >>> >>> # Load from local file >>> roi_data = load_roi_data('./my_roi_data.txt') """ # Handle different source types if source is None: # Use default CANNs dataset dataset_path = _datasets.get_dataset_path("roi_data") if dataset_path is None: return None try: data = np.loadtxt(dataset_path) print(f"Loaded ROI data: shape {data.shape}") return data except Exception as e: print(f"Failed to load ROI data: {e}") return None elif isinstance(source, str) and source.startswith(("http://", "https://")): # Load from URL try: data = _datasets.load(source, file_type="text") if isinstance(data, str): # If loaded as text, try to parse as numpy array lines = data.strip().split("\n") data = np.array([[float(x) for x in line.split()] for line in lines]) print(f"Loaded ROI data from URL: shape {data.shape}") return data except Exception as e: print(f"Failed to load ROI data from URL: {e}") return None else: # Load from local path source_path = Path(source) if not source_path.exists(): print(f"File not found: {source_path}") return None try: data = np.loadtxt(source_path) print(f"Loaded ROI data from {source_path}: shape {data.shape}") return data except Exception as e: print(f"Failed to load ROI data from {source_path}: {e}") return None
[docs] def load_grid_data( source: str | Path | None = None, dataset_key: str = "grid_1" ) -> dict[str, Any] | None: """ Load grid cell data for 2D CANN analysis. Parameters ---------- source : str, Path, or None Data source. Can be: - URL string: downloads and loads from URL - Path: loads from local file - None: uses default CANNs dataset dataset_key : str Which default dataset to use ('grid_1' or 'grid_2') when source is None. Returns ------- dict or None Dictionary containing spike data and metadata if successful, None otherwise. Expected keys: 'spike', 't', and optionally 'x', 'y' for position data. Examples -------- >>> # Load default dataset >>> grid_data = load_grid_data() >>> >>> # Load from URL >>> grid_data = load_grid_data('https://example.com/grid_data.npz') >>> >>> # Load specific default dataset >>> grid_data = load_grid_data(dataset_key='grid_2') """ # Handle different source types if source is None: # Use default CANNs dataset dataset_path = _datasets.get_dataset_path(dataset_key) if dataset_path is None: return None try: data = np.load(dataset_path, allow_pickle=True) result = { "spike": data["spike"], "t": data["t"], } # Add position data if available if "x" in data: result["x"] = data["x"] if "y" in data: result["y"] = data["y"] # Handle different spike data formats if hasattr(result["spike"], "item") and isinstance(result["spike"].item(), dict): # Spike data is stored as a dictionary inside numpy array spike_dict = result["spike"].item() print(f"Loaded {dataset_key}: {len(spike_dict)} neurons") else: print(f"Loaded {dataset_key}: {len(result['spike'])} neurons") if "x" in result: print(f"Position data available: {len(result['x'])} time points") return result except Exception as e: print(f"Failed to load {dataset_key}: {e}") return None elif isinstance(source, str) and source.startswith(("http://", "https://")): # Load from URL try: data = _datasets.load(source, file_type="numpy") if isinstance(data, dict): result = {} if "spike" in data: 
result["spike"] = data["spike"] if "t" in data: result["t"] = data["t"] if "x" in data: result["x"] = data["x"] if "y" in data: result["y"] = data["y"] print(f"Loaded grid data from URL: {len(result.get('spike', []))} neurons") return result else: print("Grid data must be in .npz format with 'spike' and 't' arrays") return None except Exception as e: print(f"Failed to load grid data from URL: {e}") return None else: # Load from local path source_path = Path(source) if not source_path.exists(): print(f"File not found: {source_path}") return None try: data = np.load(source_path, allow_pickle=True) result = { "spike": data["spike"], "t": data["t"], } # Add position data if available if "x" in data: result["x"] = data["x"] if "y" in data: result["y"] = data["y"] print(f"Loaded grid data from {source_path}: {len(result['spike'])} neurons") if "x" in result: print(f"Position data available: {len(result['x'])} time points") return result except Exception as e: print(f"Failed to load grid data from {source_path}: {e}") return None
def validate_roi_data(data: np.ndarray) -> bool:
    """
    Validate ROI data format for 1D CANN analysis.

    The data must be a non-empty, 1D or 2D numpy array of finite numeric
    values. The reason for any failure is printed.

    Parameters
    ----------
    data : ndarray
        ROI data array.

    Returns
    -------
    bool
        True if data is valid, False otherwise (never raises for bad input).
    """
    if not isinstance(data, np.ndarray):
        print("ROI data must be a numpy array")
        return False
    if data.ndim not in (1, 2):
        print(f"ROI data must be 1D or 2D, got {data.ndim}D")
        return False
    if data.size == 0:
        print("ROI data is empty")
        return False
    try:
        all_finite = bool(np.isfinite(data).all())
    except TypeError:
        # np.isfinite has no loop for string/object dtypes; report it as invalid
        # instead of letting a validator crash.
        print(f"ROI data must be numeric, got dtype {data.dtype}")
        return False
    if not all_finite:
        print("ROI data contains non-finite values")
        return False
    return True
[docs] def validate_grid_data(data: dict[str, Any]) -> bool: """ Validate grid data format for 2D CANN analysis. Parameters ---------- data : dict Grid data dictionary. Returns ------- bool True if data is valid, False otherwise. """ if not isinstance(data, dict): print("Grid data must be a dictionary") return False # Check required keys required_keys = ["spike", "t"] for key in required_keys: if key not in data: print(f"Grid data missing required key: {key}") return False # Validate spike data spike_data = data["spike"] if not isinstance(spike_data, list | np.ndarray): print("Spike data must be a list or numpy array") return False if len(spike_data) == 0: print("Spike data is empty") return False # Validate time data t_data = data["t"] if not isinstance(t_data, np.ndarray): print("Time data must be a numpy array") return False if t_data.size == 0: print("Time data is empty") return False if not np.isfinite(t_data).all(): print("Time data contains non-finite values") return False # Validate position data if present for pos_key in ["x", "y"]: if pos_key in data: pos_data = data[pos_key] if not isinstance(pos_data, np.ndarray): print(f"Position data '{pos_key}' must be a numpy array") return False if pos_data.size == 0: print(f"Position data '{pos_key}' is empty") return False if not np.isfinite(pos_data).all(): print(f"Position data '{pos_key}' contains non-finite values") return False if pos_data.shape != t_data.shape: print( f"Position data '{pos_key}' shape {pos_data.shape} doesn't match time data shape {t_data.shape}" ) return False return True
[docs] def preprocess_spike_data( spike_data: list | np.ndarray, time_window: tuple[float, float] | None = None, min_spike_count: int = 10, ) -> np.ndarray | None: """ Preprocess spike data for analysis. Parameters ---------- spike_data : list or ndarray Raw spike data. time_window : tuple, optional (start, end) time window to filter spikes. min_spike_count : int Minimum number of spikes required per neuron. Returns ------- ndarray or None Processed spike data, or None if processing fails. """ # Convert to numpy array if needed if isinstance(spike_data, list): spike_data = np.array(spike_data, dtype=object) if spike_data.size == 0: print("Spike data is empty") return None # Filter by time window if specified if time_window is not None: start_time, end_time = time_window filtered_spikes = [] for neuron_spikes in spike_data: if len(neuron_spikes) > 0: mask = (neuron_spikes >= start_time) & (neuron_spikes <= end_time) filtered_spikes.append(neuron_spikes[mask]) else: filtered_spikes.append(np.array([])) spike_data = np.array(filtered_spikes, dtype=object) # Filter neurons with insufficient spikes valid_neurons = [] for neuron_spikes in spike_data: if len(neuron_spikes) >= min_spike_count: valid_neurons.append(neuron_spikes) if len(valid_neurons) == 0: print(f"No neurons meet minimum spike count requirement ({min_spike_count})") return None if len(valid_neurons) < len(spike_data): print(f"Filtered {len(spike_data) - len(valid_neurons)} neurons with insufficient spikes") return np.array(valid_neurons, dtype=object)
[docs] def get_data_summary(data: np.ndarray | dict[str, Any]) -> dict[str, Any]: """ Get summary statistics for experimental data. Parameters ---------- data : ndarray or dict ROI data (ndarray) or grid data (dict). Returns ------- dict Summary statistics. """ summary = {} if isinstance(data, np.ndarray): # ROI data summary summary["type"] = "roi_data" summary["shape"] = data.shape summary["size"] = data.size summary["dtype"] = str(data.dtype) summary["min"] = float(np.min(data)) summary["max"] = float(np.max(data)) summary["mean"] = float(np.mean(data)) summary["std"] = float(np.std(data)) summary["has_nan"] = bool(np.isnan(data).any()) summary["has_inf"] = bool(np.isinf(data).any()) elif isinstance(data, dict): # Grid data summary summary["type"] = "grid_data" summary["keys"] = list(data.keys()) if "spike" in data: spike_data = data["spike"] summary["n_neurons"] = len(spike_data) spike_counts = [len(neuron_spikes) for neuron_spikes in spike_data] summary["spike_counts"] = { "min": int(np.min(spike_counts)), "max": int(np.max(spike_counts)), "mean": float(np.mean(spike_counts)), "total": int(np.sum(spike_counts)), } if "t" in data: t_data = data["t"] summary["time_data"] = { "length": len(t_data), "duration": float(t_data.max() - t_data.min()), "sampling_rate": float(1.0 / np.mean(np.diff(t_data))) if len(t_data) > 1 else None, } for pos_key in ["x", "y"]: if pos_key in data: pos_data = data[pos_key] summary[f"{pos_key}_data"] = { "min": float(np.min(pos_data)), "max": float(np.max(pos_data)), "range": float(np.max(pos_data) - np.min(pos_data)), } else: summary["type"] = "unknown" summary["error"] = "Unsupported data type" return summary