src.canns.task.navigation_base

Base navigation task with geodesic distance computation capabilities.

Classes

BaseNavigationTask

Base class for navigation tasks with geodesic distance computation support.

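GeodesicDistanceResult

Container for a geodesic distance matrix and its metadata.

MovementCostGrid

Grid-based movement cost map describing the discretised environment.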

Module Contents

class src.canns.task.navigation_base.BaseNavigationTask(start_pos=(2.5, 2.5), width=5, height=5, dimensionality='2D', boundary_conditions='solid', scale=None, dx=0.01, grid_dx=None, grid_dy=None, boundary=None, walls=None, holes=None, objects=None, dt=None, speed_mean=0.04, speed_std=0.016, speed_coherence_time=0.7, rotational_velocity_coherence_time=0.08, rotational_velocity_std=120 * np.pi / 180, head_direction_smoothing_timescale=0.15, initial_head_direction=None, thigmotaxis=0.5, wall_repel_distance=0.1, wall_repel_strength=1.0, data_class=None)

Bases: src.canns.task._base.Task

Base class for navigation tasks with geodesic distance computation support.

This class provides common functionality for both open-loop and closed-loop navigation tasks, including environment setup, agent initialization, and geodesic distance computation on discretized grids.

Initializes the Task instance.

Parameters:

data_class (type, optional) – A dataclass type for structured data. If provided, the task will use this class to structure the loaded or generated data.

build_movement_cost_grid(*, refresh=False)

Construct a grid-based movement cost map for the configured environment.

A cell weight of 1 indicates free space, while INT32_MAX marks an impassable cell (intersecting a wall/hole or lying outside the boundary).

Parameters:

refresh (bool) – Force recomputation even if a cached grid is available.

Returns:

MovementCostGrid describing the discretised environment.

Return type:

MovementCostGrid
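To make the weighting convention concrete, here is a minimal sketch that rasterises a rectangular arena into cells and marks every cell overlapping a hole as impassable. It assumes holes are axis-aligned rectangles given as (xmin, ymin, xmax, ymax) and that rows index y; the helper build_cost_grid is hypothetical. The documented method also accounts for walls and the environment boundary, which this sketch omits.

```python
import numpy as np

# Sentinel for impassable cells, following the INT32_MAX convention above.
INT32_MAX = np.iinfo(np.int32).max


def build_cost_grid(width, height, dx, dy, holes):
    """Hypothetical sketch: rasterise a width x height arena into dx x dy cells.

    holes: list of axis-aligned rectangles (xmin, ymin, xmax, ymax). Any cell whose
    interior overlaps a rectangle is marked impassable; all other cells cost 1.
    """
    nx = int(round(width / dx))
    ny = int(round(height / dy))
    x_edges = np.linspace(0.0, width, nx + 1)
    y_edges = np.linspace(0.0, height, ny + 1)
    costs = np.ones((ny, nx), dtype=np.int32)  # rows index y, columns index x
    for xmin, ymin, xmax, ymax in holes:
        # Strict inequalities: cells that merely touch the rectangle stay free.
        overlap_x = (x_edges[:-1] < xmax) & (x_edges[1:] > xmin)
        overlap_y = (y_edges[:-1] < ymax) & (y_edges[1:] > ymin)
        costs[overlap_y[:, None] & overlap_x[None, :]] = INT32_MAX
    return costs, x_edges, y_edges


costs, x_edges, y_edges = build_cost_grid(5.0, 5.0, 1.0, 1.0, holes=[(2.0, 2.0, 3.0, 3.0)])
print(costs.shape, int((costs == INT32_MAX).sum()))  # (5, 5) 1
```

With 1 × 1 cells, the unit hole at (2, 2) to (3, 3) blocks exactly one cell.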

compute_geodesic_distance_matrix(dx=None, dy=None, *, refresh=False)

Compute pairwise geodesic distances between traversable grid cells.

The computation treats each traversable cell (weight 1) as a graph node connected to its four axis-aligned neighbours. Horizontal steps cost dx and vertical steps cost dy. Impassable cells (INT32_MAX) are ignored.

When Numba is available, this method runs the Dijkstra searches in parallel across CPU cores, which is typically 4-8x faster on multi-core systems. Without Numba, it falls back to a sequential Python implementation that displays a progress bar.

Parameters:
  • dx (float | None) – Grid cell width along the x axis. When None the existing grid_dx attribute is used.

  • dy (float | None) – Grid cell height along the y axis. When None the existing grid_dy attribute is used.

  • refresh (bool) – Force recomputation even if cached results exist.

Returns:

GeodesicDistanceResult containing the distance matrix and metadata.

Return type:

GeodesicDistanceResult

Note

The parallel Numba implementation cannot display a progress bar during computation; it prints start and end messages instead.
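The traversal rule described above (four-connected neighbours, horizontal steps cost dx, vertical steps cost dy, impassable cells skipped) can be sketched as one Dijkstra search per free cell. This is a pure-Python illustration, not the library's code: it assumes the cost grid is a NumPy array whose rows index y, the function name geodesic_distances is hypothetical, and unreachable pairs are left at infinity, which may differ from the library. Because each search is independent, a parallel variant could plausibly distribute the outer loop over source cells.

```python
import heapq

import numpy as np

INT32_MAX = np.iinfo(np.int32).max


def geodesic_distances(costs, dx, dy):
    """Sketch: all-pairs shortest paths over the free cells of a 4-connected grid.

    Returns (accessible, distances): accessible holds the row-major flat indices of
    free cells, and distances[i, j] is the path length between accessible[i] and
    accessible[j] (np.inf if no path exists).
    """
    ny, nx = costs.shape
    free = costs != INT32_MAX
    accessible = np.flatnonzero(free)
    position = {int(f): i for i, f in enumerate(accessible)}  # flat index -> matrix row
    distances = np.full((len(accessible), len(accessible)), np.inf)
    steps = ((0, 1, dx), (0, -1, dx), (1, 0, dy), (-1, 0, dy))  # (drow, dcol, cost)
    for src_i, src in enumerate(accessible):
        row_dist = distances[src_i]  # view: writes go straight into the matrix
        row_dist[src_i] = 0.0
        heap = [(0.0, int(src))]
        while heap:
            d, flat = heapq.heappop(heap)
            if d > row_dist[position[flat]]:
                continue  # stale entry superseded by a shorter path
            r, c = divmod(flat, nx)
            for dr, dc, step in steps:
                rr, cc = r + dr, c + dc
                if 0 <= rr < ny and 0 <= cc < nx and free[rr, cc]:
                    nb = rr * nx + cc
                    nd = d + step
                    j = position[nb]
                    if nd < row_dist[j]:
                        row_dist[j] = nd
                        heapq.heappush(heap, (nd, nb))
    return accessible, distances


costs = np.ones((3, 3), dtype=np.int32)
costs[1, 1] = INT32_MAX  # block the centre cell
accessible, distances = geodesic_distances(costs, dx=1.0, dy=0.5)
print(distances[0, -1])  # 3.0
```

On this 3 × 3 grid the path from one corner to the opposite corner must go around the blocked centre: two horizontal steps (2 × 1.0) plus two vertical steps (2 × 0.5).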

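get_geodesic_index_by_pos(pos, *, refresh=False)

Get the index of the grid cell containing the given position.

Parameters:
  • pos (collections.abc.Sequence[float]) – (x, y) coordinates of the position.

  • refresh (bool) – Force recomputation even if cached results exist.

Returns: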

Index of the grid cell in the geodesic distance matrix, or None if the position is out of bounds or in an impassable cell.

Return type:

int | None
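A minimal sketch of the position-to-index lookup, assuming half-open cells [edge_k, edge_k+1), rows indexing y, and an accessible array holding the sorted row-major flat indices of free cells. The helper index_by_pos and its arguments are hypothetical; the library method works from its own cached grid instead.

```python
import numpy as np

# Sentinel for impassable cells (assumed to match the INT32_MAX convention above).
INT32_MAX = np.iinfo(np.int32).max


def index_by_pos(pos, x_edges, y_edges, costs, accessible):
    """Hypothetical helper: map an (x, y) position to a row of the distance matrix.

    Returns None if pos lies outside the grid or inside an impassable cell.
    """
    x, y = pos
    ny, nx = costs.shape
    # Cell k spans [edges[k], edges[k + 1]); searchsorted with side="right" finds k + 1.
    col = int(np.searchsorted(x_edges, x, side="right")) - 1
    row = int(np.searchsorted(y_edges, y, side="right")) - 1
    if not (0 <= col < nx and 0 <= row < ny):
        return None
    if costs[row, col] == INT32_MAX:
        return None
    flat = row * nx + col  # row-major flat index into the cost grid
    hits = np.flatnonzero(accessible == flat)
    return int(hits[0]) if hits.size else None


costs = np.ones((3, 3), dtype=np.int32)
costs[1, 1] = INT32_MAX
edges = np.linspace(0.0, 3.0, 4)
accessible = np.flatnonzero(costs != INT32_MAX)
print(index_by_pos((2.5, 0.5), edges, edges, costs, accessible))  # 2
print(index_by_pos((1.5, 1.5), edges, edges, costs, accessible))  # None
```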

set_grid_resolution(dx, dy)

Update the stored grid resolution and invalidate cached data.
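The caching contract implied by the refresh flags and this method can be illustrated with a toy class. It assumes that the "cached data" are the cost_grid and geodesic_result attributes documented for this class; CachedGridTask and its builds counter are hypothetical, and the cost grid is a placeholder tuple rather than a real MovementCostGrid.

```python
class CachedGridTask:
    """Hypothetical toy class illustrating the refresh/invalidation contract."""

    def __init__(self, grid_dx=0.01, grid_dy=0.01):
        self.grid_dx = grid_dx
        self.grid_dy = grid_dy
        self.cost_grid = None        # assumed cache for the movement cost grid
        self.geodesic_result = None  # assumed cache for geodesic distances
        self.builds = 0              # counts expensive rebuilds, for illustration only

    def build_movement_cost_grid(self, *, refresh=False):
        if self.cost_grid is None or refresh:
            self.builds += 1
            # Placeholder for the real rasterisation work.
            self.cost_grid = ("grid", self.grid_dx, self.grid_dy)
        return self.cost_grid

    def set_grid_resolution(self, dx, dy):
        self.grid_dx, self.grid_dy = dx, dy
        # Both caches depend on the resolution, so they are dropped rather than patched.
        self.cost_grid = None
        self.geodesic_result = None


task = CachedGridTask()
first = task.build_movement_cost_grid()
assert task.build_movement_cost_grid() is first  # cached: no rebuild
task.set_grid_resolution(0.05, 0.05)
task.build_movement_cost_grid()
print(task.builds)  # 2
```

Dropping the caches instead of recomputing them eagerly keeps set_grid_resolution cheap; the cost is paid on the next call that needs the grid.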

show_data(show=True, save_path=None, *, overlay_movement_cost=False, cost_grid=None, free_color='#f8f9fa', blocked_color='#f94144', gridline_color='#2b2d42', cost_alpha=0.6, show_colorbar=False, cost_legend_loc=None)

Display the agent’s trajectory with optional movement cost grid overlay.

Parameters:
  • show (bool) – Whether to display the plot.

  • save_path (str | None) – Path to save the figure. If None, the figure is not saved.

  • overlay_movement_cost (bool) – Whether to overlay the movement cost grid.

  • cost_grid (MovementCostGrid | None) – Pre-computed cost grid. If None and overlay_movement_cost is True, the grid will be built on demand.

  • free_color (str) – Color for free (accessible) cells in the cost grid.

  • blocked_color (str) – Color for blocked (inaccessible) cells in the cost grid.

  • gridline_color (str) – Color for grid lines.

  • cost_alpha (float) – Transparency of the cost grid overlay (0=transparent, 1=opaque).

  • show_colorbar (bool) – Whether to show a colorbar for the cost grid.

  • cost_legend_loc (str | None) – Location of the legend for the cost grid (e.g., ‘upper right’). If None, no legend is shown.
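As an illustration of how free_color, blocked_color and cost_alpha might combine, the sketch below turns a cost grid into an RGBA image that could be drawn over the trajectory, for example with matplotlib's imshow. The function names are hypothetical and the library's plotting code may differ; matplotlib.colors.to_rgba could replace the small hex parser, which only handles the '#rrggbb' form.

```python
import numpy as np

INT32_MAX = np.iinfo(np.int32).max


def hex_to_rgb(color):
    """Convert a '#rrggbb' string to three floats in [0, 1]."""
    color = color.lstrip("#")
    return tuple(int(color[i:i + 2], 16) / 255.0 for i in (0, 2, 4))


def cost_overlay_rgba(costs, free_color="#f8f9fa", blocked_color="#f94144", cost_alpha=0.6):
    """Sketch: build an (ny, nx, 4) RGBA image from a cost grid for an overlay."""
    rgba = np.empty(costs.shape + (4,), dtype=float)
    blocked = costs == INT32_MAX
    rgba[~blocked, :3] = hex_to_rgb(free_color)
    rgba[blocked, :3] = hex_to_rgb(blocked_color)
    rgba[..., 3] = cost_alpha  # uniform transparency for the whole overlay
    return rgba


costs = np.array([[1, INT32_MAX], [1, 1]], dtype=np.int32)
rgba = cost_overlay_rgba(costs)
print(rgba.shape)  # (2, 2, 4)
```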

show_geodesic_distance_matrix(dx=None, dy=None, *, show=True, save_path=None, cmap='viridis', normalize=False, colorbar=True, refresh=False)

Visualise the geodesic distance matrix for the discretised environment.

agent
agent_params
aspect = 1.0
boundary
boundary_conditions = 'solid'
cost_grid: MovementCostGrid | None = None
dimensionality = ''
dt = None
dx = 0.01
env
env_params
geodesic_result: GeodesicDistanceResult | None = None
grid_dx = 0.01
grid_dy = 0.01
head_direction_smoothing_timescale = 0.15
height = 5
holes
initial_head_direction = None
objects
rotational_velocity_coherence_time = 0.08
rotational_velocity_std
scale = 5
speed_coherence_time = 0.7
speed_mean = 0.04
speed_std = 0.016
start_pos = (2.5, 2.5)
thigmotaxis = 0.5
wall_repel_distance = 0.1
wall_repel_strength = 1.0
walls
width = 5
class src.canns.task.navigation_base.GeodesicDistanceResult

accessible_indices: numpy.ndarray
cost_grid: MovementCostGrid
distances: numpy.ndarray

class src.canns.task.navigation_base.MovementCostGrid

get_cell_index(pos)

Get the geodesic index of the grid cell containing the given position.

This method is JAX-compatible and can be used inside jitted functions.

Parameters:

pos (collections.abc.Sequence[float]) – (x, y) coordinates of the position.

Returns:

Index of the grid cell in the accessible_indices array, or -1 if the position is out of bounds or in an impassable cell.

Return type:

int

Note

Returns -1 (instead of None) for JAX compatibility. The caller should check for negative values to detect invalid positions.
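Returning -1 rather than None matters because, inside jax.jit, code cannot branch in Python on traced values. The following is a minimal branch-free sketch under stated assumptions: cells are half-open [edge_k, edge_k+1), rows index y, and accessible_indices holds the sorted row-major flat indices of free cells. It takes an array-module argument xp so the same body could in principle run with jax.numpy; that port is an assumption, and this version is only exercised with NumPy. The free function mirrors the documented method's name but is not the library's implementation.

```python
import numpy as np

INT32_MAX = np.iinfo(np.int32).max


def get_cell_index(pos, x_edges, y_edges, costs, accessible_indices, xp=np):
    """Branch-free sketch: index into accessible_indices, or -1 if pos is invalid.

    Only array operations are used (no Python if on array values), which is the
    property a jitted version needs. xp is the array module (numpy here).
    """
    x, y = pos[0], pos[1]
    ny, nx = costs.shape
    col = xp.searchsorted(x_edges, x, side="right") - 1
    row = xp.searchsorted(y_edges, y, side="right") - 1
    in_bounds = (col >= 0) & (col < nx) & (row >= 0) & (row < ny)
    # Clip so the lookups below stay valid even when pos is out of bounds.
    col_c = xp.clip(col, 0, nx - 1)
    row_c = xp.clip(row, 0, ny - 1)
    free = costs[row_c, col_c] != INT32_MAX
    flat = row_c * nx + col_c
    # accessible_indices is sorted, so searchsorted finds flat's slot if present.
    k = xp.clip(xp.searchsorted(accessible_indices, flat), 0, accessible_indices.shape[0] - 1)
    found = accessible_indices[k] == flat
    return xp.where(in_bounds & free & found, k, -1)


costs = np.ones((3, 3), dtype=np.int32)
costs[1, 1] = INT32_MAX
edges = np.linspace(0.0, 3.0, 4)
acc = np.flatnonzero(costs != INT32_MAX)
print(int(get_cell_index((1.5, 1.5), edges, edges, costs, acc)))  # -1
```

Callers then test for a negative result, as the note above describes, instead of checking for None.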

accessible_indices: numpy.ndarray | None = None
property accessible_mask: numpy.ndarray
costs: numpy.ndarray
dx: float
dy: float
property shape: tuple[int, int]
property x_centers: numpy.ndarray
x_edges: numpy.ndarray
property y_centers: numpy.ndarray
y_edges: numpy.ndarray
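The properties of MovementCostGrid are not documented individually. One plausible reading, sketched with a hypothetical stand-in dataclass, is that the centres are midpoints of adjacent edges, shape is the (rows, columns) shape of costs, accessible_mask flags cells whose cost is not INT32_MAX, and accessible_indices lists the row-major flat indices of those cells. All of these are assumptions inferred from the attribute names and the cost convention described above, not confirmed behaviour.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

INT32_MAX = np.iinfo(np.int32).max


@dataclass
class CostGridSketch:
    """Hypothetical stand-in mirroring the documented MovementCostGrid fields."""
    costs: np.ndarray    # (ny, nx) int32; 1 = free, INT32_MAX = impassable
    x_edges: np.ndarray  # nx + 1 cell boundaries along x
    y_edges: np.ndarray  # ny + 1 cell boundaries along y
    dx: float
    dy: float
    accessible_indices: Optional[np.ndarray] = None

    def __post_init__(self):
        if self.accessible_indices is None:
            # Assumed convention: row-major flat indices of free cells.
            self.accessible_indices = np.flatnonzero(self.accessible_mask)

    @property
    def shape(self):
        return self.costs.shape

    @property
    def x_centers(self):
        # Midpoint of each pair of adjacent edges.
        return 0.5 * (self.x_edges[:-1] + self.x_edges[1:])

    @property
    def y_centers(self):
        return 0.5 * (self.y_edges[:-1] + self.y_edges[1:])

    @property
    def accessible_mask(self):
        return self.costs != INT32_MAX


costs = np.ones((2, 2), dtype=np.int32)
costs[1, 0] = INT32_MAX
grid = CostGridSketch(costs, np.array([0.0, 1.0, 2.0]), np.array([0.0, 0.5, 1.0]), dx=1.0, dy=0.5)
print(grid.accessible_indices.tolist())  # [0, 1, 3]
```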