# Source code for src.canns.analyzer.brain_inspired.hopfield

"""Hopfield network analysis tools."""

from __future__ import annotations

import jax.numpy as jnp
import numpy as np

__all__ = ["HopfieldAnalyzer"]


class HopfieldAnalyzer:
    """
    Analyzer for Hopfield associative memory networks.

    Provides diagnostic and analysis tools for Hopfield networks including:

    - Pattern storage capacity estimation
    - Energy landscape analysis
    - Overlap metrics for pattern retrieval
    - Recall quality diagnostics

    The Hopfield network stores patterns as attractors in an energy landscape.
    Energy function: E = -0.5 * s^T W s

    The weight matrix is read from the model attribute named by
    ``model.weight_attr`` (a string, or a callable returning one; defaults to
    ``"W"``). If that attribute has a ``.value`` it is unwrapped, otherwise the
    attribute itself is used as the matrix.

    Reference:
        Hopfield, J. J. (1982). Neural networks and physical systems with
        emergent collective computational abilities. PNAS, 79(8), 2554-2558.
    """

    def __init__(self, model, stored_patterns: list | None = None):
        """
        Initialize Hopfield analyzer.

        Args:
            model: The Hopfield network model to analyze
            stored_patterns: List of patterns stored in the network (optional)
        """
        self.model = model
        self.stored_patterns = []
        self._pattern_energies = []

        # Route through set_patterns so constructor-supplied patterns get the
        # same float32 conversion (and a private copy of the list) as patterns
        # supplied later. The model is not touched when no patterns are given.
        if stored_patterns is not None and len(stored_patterns) > 0:
            self.set_patterns(stored_patterns)

    def _get_weight_matrix(self):
        """Return the model's weight matrix as a JAX array."""
        weight_attr = getattr(self.model, "weight_attr", "W")
        if callable(weight_attr):
            weight_attr = weight_attr()
        weight_param = getattr(self.model, weight_attr)
        # Unwrap parameter containers; fall back to the attribute itself when
        # the model exposes the matrix directly.
        return jnp.asarray(getattr(weight_param, "value", weight_param))

    @staticmethod
    def _energy(W, pattern) -> float:
        """Return E = -0.5 * s^T W s for a single pattern ``s``."""
        s = jnp.asarray(pattern, dtype=jnp.float32)
        return float(-0.5 * jnp.dot(s, jnp.dot(W, s)))

    def set_patterns(self, patterns: list):
        """
        Set the stored patterns and compute their energies.

        Args:
            patterns: List of patterns stored in the network
        """
        self.stored_patterns = [jnp.asarray(p, dtype=jnp.float32) for p in patterns]
        self.compute_pattern_energies()

    def compute_pattern_energies(self):
        """Compute energy for each stored pattern."""
        # Fetch the weight matrix once; it is the same for every pattern.
        W = self._get_weight_matrix()
        self._pattern_energies = [self._energy(W, p) for p in self.stored_patterns]

    @property
    def pattern_energies(self) -> list[float]:
        """Get energies of stored patterns."""
        return self._pattern_energies

    def compute_overlap(self, pattern1: jnp.ndarray, pattern2: jnp.ndarray) -> float:
        """
        Compute normalized overlap between two patterns.

        The overlap is the dot product divided by the length of ``pattern1``.

        Args:
            pattern1: First pattern
            pattern2: Second pattern

        Returns:
            Overlap value; for +/-1-valued patterns of equal length this lies
            between -1 and 1
        """
        p1 = jnp.asarray(pattern1, dtype=jnp.float32)
        p2 = jnp.asarray(pattern2, dtype=jnp.float32)
        return float(jnp.dot(p1, p2) / len(p1))

    def compute_energy(self, pattern: jnp.ndarray) -> float:
        """
        Compute energy of a given pattern.

        Args:
            pattern: Pattern to compute energy for

        Returns:
            Energy value E = -0.5 * s^T W s
        """
        return self._energy(self._get_weight_matrix(), pattern)

    def analyze_recall(self, input_pattern: jnp.ndarray, output_pattern: jnp.ndarray) -> dict:
        """
        Analyze pattern recall quality.

        Args:
            input_pattern: Input (noisy) pattern
            output_pattern: Recalled pattern

        Returns:
            Dictionary with diagnostic metrics:

            - best_match_idx: Index of best matching stored pattern
              (only present when patterns are stored)
            - best_match_overlap: Overlap with best matching pattern
              (only present when patterns are stored)
            - input_output_overlap: Overlap between input and output
            - output_energy: Energy of the recalled pattern
        """
        diagnostics = {}

        # Find best matching stored pattern
        if len(self.stored_patterns) > 0:
            overlaps = [
                self.compute_overlap(output_pattern, stored) for stored in self.stored_patterns
            ]
            best_idx = int(np.argmax(overlaps))
            diagnostics["best_match_idx"] = best_idx
            diagnostics["best_match_overlap"] = overlaps[best_idx]

        # Input-output overlap
        diagnostics["input_output_overlap"] = self.compute_overlap(input_pattern, output_pattern)

        # Energy of recalled pattern
        diagnostics["output_energy"] = self.compute_energy(output_pattern)

        return diagnostics

    def estimate_capacity(self) -> int:
        """
        Estimate theoretical storage capacity of the network.

        Uses the rule of thumb: capacity ≈ N / (4 * ln(N))
        where N is the number of neurons. If the model defines
        ``storage_capacity`` that value is returned instead; if it does not
        define ``num_neurons``, N defaults to 100.

        Returns:
            Estimated number of patterns that can be reliably stored
            (at least 1 when computed from the rule of thumb)
        """
        if hasattr(self.model, "storage_capacity"):
            return self.model.storage_capacity

        n = getattr(self.model, "num_neurons", 100)
        # ln(N) is zero at N == 1 and undefined below it, so the formula cannot
        # be evaluated there; fall back to the minimum estimate.
        if n <= 1:
            return 1
        return max(1, int(n / (4 * np.log(n))))

    def get_statistics(self) -> dict:
        """
        Get comprehensive statistics about stored patterns.

        Returns:
            Dictionary with pattern statistics:

            - num_patterns: Number of stored patterns
            - capacity_estimate: Theoretical capacity estimate
            - capacity_usage: Fraction of capacity used
            - mean_pattern_energy: Mean energy of stored patterns
            - std_pattern_energy: Standard deviation of energies
            - min_pattern_energy: Minimum energy
            - max_pattern_energy: Maximum energy

            The four energy entries are only present when pattern energies
            have been computed.
        """
        num_patterns = len(self.stored_patterns)
        capacity = self.estimate_capacity()
        stats = {
            "num_patterns": num_patterns,
            "capacity_estimate": capacity,
            "capacity_usage": num_patterns / max(1, capacity),
        }

        if len(self._pattern_energies) > 0:
            stats["mean_pattern_energy"] = float(np.mean(self._pattern_energies))
            stats["std_pattern_energy"] = float(np.std(self._pattern_energies))
            stats["min_pattern_energy"] = float(np.min(self._pattern_energies))
            stats["max_pattern_energy"] = float(np.max(self._pattern_energies))

        return stats

    def compute_weight_symmetry_error(self) -> float:
        """
        Compute the symmetry error of the weight matrix.

        Hopfield networks require symmetric weights (W_ij = W_ji).
        This metric quantifies how much the weight matrix deviates from symmetry.

        Returns:
            Symmetry error as ||W - W^T||_F / ||W||_F, or 0.0 when W is
            entirely zero
        """
        W = self._get_weight_matrix()
        weight_norm = jnp.linalg.norm(W, "fro")
        # An all-zero matrix is trivially symmetric; without this guard the
        # ratio below would be 0/0 (NaN).
        if float(weight_norm) == 0.0:
            return 0.0
        return float(jnp.linalg.norm(W - W.T, "fro") / weight_norm)