Source code for src.canns.analyzer.slow_points.checkpoint

"""Checkpoint utilities for saving and loading trained RNN models using BrainPy's built-in checkpointing."""

import os
import warnings

import brainpy as bp

# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["save_checkpoint", "load_checkpoint"]


def save_checkpoint(model: bp.DynamicalSystem, filepath: str) -> None:
    """Write the state of ``model`` to ``filepath`` using BrainPy checkpointing.

    The model's states are gathered with ``bp.save_state`` and written to disk
    with ``bp.checkpoints.save_pytree``. Any existing file at ``filepath`` is
    overwritten. A confirmation line is printed once the write returns.

    Args:
        model: BrainPy model whose state should be saved.
        filepath: Destination path of the checkpoint file.

    Example:
        >>> from canns.analyzer.slow_points import save_checkpoint
        >>> save_checkpoint(rnn, "my_model.msgpack")
        Saved checkpoint to: my_model.msgpack
    """
    # Collect every state of the model (parameters and state variables) into
    # a pytree and hand it straight to the checkpoint writer. overwrite=True
    # replaces an existing checkpoint instead of failing.
    # NOTE(review): BrainPy is expected to create missing parent directories
    # here -- confirm for the installed BrainPy version.
    bp.checkpoints.save_pytree(filepath, bp.save_state(model), overwrite=True)

    print(f"Saved checkpoint to: {filepath}")


def load_checkpoint(model: bp.DynamicalSystem, filepath: str) -> bool:
    """Load model parameters from a checkpoint file using BrainPy checkpointing.

    The checkpoint is read with ``bp.checkpoints.load_pytree`` and applied to
    ``model`` in place with ``bp.load_state``.

    Args:
        model: BrainPy model to load parameters into.
        filepath: Path to the checkpoint file.

    Returns:
        True if the checkpoint was loaded successfully. False if ``filepath``
        does not exist, or if reading or applying the checkpoint raised
        ``ValueError`` or ``OSError`` (corrupt file, permission error, ...).
        In the latter case a ``RuntimeWarning`` describing the failure is
        emitted, so an unreadable checkpoint is not mistaken for a missing one.

    Example:
        >>> from canns.analyzer.slow_points import load_checkpoint
        >>> if load_checkpoint(rnn, "my_model.msgpack"):
        ...     print("Loaded successfully")
        ... else:
        ...     print("No checkpoint found")
        Loaded checkpoint from: my_model.msgpack
        Loaded successfully
    """
    # A missing path simply means "no checkpoint yet": report it quietly.
    if not os.path.exists(filepath):
        return False

    try:
        state_dict = bp.checkpoints.load_pytree(filepath)
        bp.load_state(model, state_dict)
    except (ValueError, OSError) as err:
        # The file exists but could not be read or applied. OSError also
        # covers FileNotFoundError, e.g. when the file disappears after the
        # existence check above. Keep the best-effort False return, but
        # surface the reason instead of swallowing it silently.
        # NOTE(review): if bp.load_state fails part-way, the model may be
        # left partially updated -- confirm against BrainPy's semantics.
        warnings.warn(
            f"Failed to load checkpoint from {filepath}: {err}",
            RuntimeWarning,
            stacklevel=2,
        )
        return False

    print(f"Loaded checkpoint from: {filepath}")
    return True