src.canns.task.navigation_base
Base navigation task with geodesic distance computation capabilities.
Classes
- BaseNavigationTask – Base class for navigation tasks with geodesic distance computation support.
- GeodesicDistanceResult
- MovementCostGrid
Module Contents
- class src.canns.task.navigation_base.BaseNavigationTask(start_pos=(2.5, 2.5), width=5, height=5, dimensionality='2D', boundary_conditions='solid', scale=None, dx=0.01, grid_dx=None, grid_dy=None, boundary=None, walls=None, holes=None, objects=None, dt=None, speed_mean=0.04, speed_std=0.016, speed_coherence_time=0.7, rotational_velocity_coherence_time=0.08, rotational_velocity_std=120 * np.pi / 180, head_direction_smoothing_timescale=0.15, initial_head_direction=None, thigmotaxis=0.5, wall_repel_distance=0.1, wall_repel_strength=1.0, data_class=None)
Bases: src.canns.task._base.Task
Base class for navigation tasks with geodesic distance computation support.
This class provides common functionality for both open-loop and closed-loop navigation tasks, including environment setup, agent initialization, and geodesic distance computation on discretized grids.
Initializes the Task instance.
- Parameters:
data_class (type, optional) – A dataclass type for structured data. If provided, the task will use this class to structure the loaded or generated data.
- build_movement_cost_grid(*, refresh=False)
Construct a grid-based movement cost map for the configured environment.
A cell weight of 1 indicates free space, while INT32_MAX marks an impassable cell (one that intersects a wall or hole, or lies outside the boundary).
- Parameters:
refresh (bool) – Force recomputation even if a cached grid is available.
- Returns:
MovementCostGrid describing the discretised environment.
- Return type:
MovementCostGrid
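The cost convention above (1 for free cells, INT32_MAX for blocked ones) can be illustrated with a minimal NumPy sketch. It assumes a rectangular box rasterised into dx-by-dy cells, with rows indexed by y and columns by x, and holes given as axis-aligned rectangles. The function build_cost_grid and its hole format are hypothetical and not part of the canns API; the real method also accounts for walls and custom boundaries.

```python
import numpy as np

INT32_MAX = np.iinfo(np.int32).max  # sentinel weight for impassable cells


def build_cost_grid(width, height, dx, dy, holes):
    """Hypothetical sketch: rasterise a width x height box into dx x dy cells.

    holes: list of axis-aligned rectangles (xmin, ymin, xmax, ymax). A cell is
    marked impassable if its rectangle overlaps any hole with positive area.
    """
    x_edges = np.linspace(0.0, width, int(round(width / dx)) + 1)
    y_edges = np.linspace(0.0, height, int(round(height / dy)) + 1)
    # Rows follow y, columns follow x (an assumed layout); 1 = free space
    costs = np.ones((len(y_edges) - 1, len(x_edges) - 1), dtype=np.int32)
    for hx0, hy0, hx1, hy1 in holes:
        # Per-axis interval overlap; strict inequalities ignore edge-only contact
        x_hit = (x_edges[:-1] < hx1) & (x_edges[1:] > hx0)
        y_hit = (y_edges[:-1] < hy1) & (y_edges[1:] > hy0)
        costs[np.ix_(y_hit, x_hit)] = INT32_MAX
    return x_edges, y_edges, costs
```

For a 1 x 1 box with 0.25 cells and one hole spanning x in (0.25, 0.75) and y in (0.25, 0.5), exactly the two cells in row 1, columns 1 and 2, become impassable.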
- compute_geodesic_distance_matrix(dx=None, dy=None, *, refresh=False)
Compute pairwise geodesic distances between traversable grid cells.
The computation treats each traversable cell (weight 1) as a graph node connected to its four axis-aligned neighbours. Horizontal steps cost dx and vertical steps cost dy. Impassable cells (INT32_MAX) are excluded from the graph.
When Numba is available, this method runs Dijkstra's algorithm in parallel across CPU cores, which typically gives a 4-8x speedup on multi-core systems. Without Numba, it falls back to a sequential Python implementation that displays a progress bar.
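- Parameters:
dx – Cost of a horizontal step between neighbouring cells.
dy – Cost of a vertical step between neighbouring cells.
refresh (bool) – Force recomputation even if a cached result is available.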
- Returns:
GeodesicDistanceResult containing the distance matrix and metadata.
- Return type:
GeodesicDistanceResult
Note
The parallel Numba implementation cannot show a progress bar during computation, but prints start/end messages instead.
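For intuition, here is a minimal sequential sketch of the computation described above: one Dijkstra run per traversable cell on the 4-neighbour grid graph, with horizontal steps costing dx and vertical steps costing dy. The function geodesic_distance_matrix is hypothetical. It mirrors only the documented behaviour and says nothing about the actual Numba kernel. Leaving unreachable pairs at inf is also an assumption.

```python
import heapq

import numpy as np


def geodesic_distance_matrix(costs, dx, dy):
    """Hypothetical sketch: all-pairs shortest paths over free cells (cost == 1).

    Returns (accessible, D): the row-major flat indices of the free cells and
    the pairwise distance matrix between them (inf where unreachable).
    """
    n_rows, n_cols = costs.shape
    free = costs == 1
    accessible = np.flatnonzero(free)  # flat indices of traversable cells
    node_of = {int(f): k for k, f in enumerate(accessible)}
    D = np.full((len(accessible), len(accessible)), np.inf)
    # (row step, column step, edge weight): horizontal = dx, vertical = dy
    steps = ((0, 1, dx), (0, -1, dx), (1, 0, dy), (-1, 0, dy))
    for src_k, src_flat in enumerate(accessible):
        dist = {int(src_flat): 0.0}
        heap = [(0.0, int(src_flat))]
        while heap:
            d, flat = heapq.heappop(heap)
            if d > dist[flat]:
                continue  # stale heap entry; a shorter path was already found
            r, c = divmod(flat, n_cols)
            for dr, dc, w in steps:
                nr, nc = r + dr, c + dc
                if 0 <= nr < n_rows and 0 <= nc < n_cols and free[nr, nc]:
                    nflat = nr * n_cols + nc
                    nd = d + w
                    if nd < dist.get(nflat, np.inf):
                        dist[nflat] = nd
                        heapq.heappush(heap, (nd, nflat))
        for flat, d in dist.items():
            D[src_k, node_of[flat]] = d
    return accessible, D
```

Each per-source run is independent of the others, so the outer loop parallelises naturally across cores. That independence is what makes a multi-core implementation worthwhile for large grids.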
- get_geodesic_index_by_pos(pos, *, refresh=False)
Get the index of the grid cell containing the given position.
- Parameters:
pos (collections.abc.Sequence[float]) – (x, y) coordinates of the position.
refresh (bool) – Recompute the cached grid before querying the index.
- Returns:
Index of the grid cell in the geodesic distance matrix, or None if the position is out of bounds or in an impassable cell.
- Return type:
int | None
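A minimal sketch of such a lookup follows. It assumes row-major cell ordering and that the rows of the geodesic distance matrix follow the order of the free cells. The function index_by_pos and its arguments are hypothetical.

```python
import numpy as np


def index_by_pos(pos, x_edges, y_edges, costs):
    """Hypothetical sketch: map (x, y) to a row of the geodesic distance matrix."""
    x, y = pos
    if not (x_edges[0] <= x <= x_edges[-1] and y_edges[0] <= y <= y_edges[-1]):
        return None  # outside the discretised environment
    n_rows, n_cols = costs.shape
    # searchsorted(side="right") - 1 finds the cell whose left edge is <= x;
    # clamping puts a point lying exactly on the far edge into the last cell
    col = min(int(np.searchsorted(x_edges, x, side="right")) - 1, n_cols - 1)
    row = min(int(np.searchsorted(y_edges, y, side="right")) - 1, n_rows - 1)
    free = (costs == 1).ravel()
    flat = row * n_cols + col
    if not free[flat]:
        return None  # impassable cell
    # Rank among free cells (row-major) = index into the distance matrix
    return int(free[:flat].sum())
```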
- show_data(show=True, save_path=None, *, overlay_movement_cost=False, cost_grid=None, free_color='#f8f9fa', blocked_color='#f94144', gridline_color='#2b2d42', cost_alpha=0.6, show_colorbar=False, cost_legend_loc=None)
Display the agent’s trajectory with optional movement cost grid overlay.
- Parameters:
show (bool) – Whether to display the plot.
save_path (str | None) – Path to save the figure. If None, the figure is not saved.
overlay_movement_cost (bool) – Whether to overlay the movement cost grid.
cost_grid (MovementCostGrid | None) – Pre-computed cost grid. If None and overlay_movement_cost is True, the grid will be built on demand.
free_color (str) – Color for free (accessible) cells in the cost grid.
blocked_color (str) – Color for blocked (inaccessible) cells in the cost grid.
gridline_color (str) – Color for grid lines.
cost_alpha (float) – Transparency of the cost grid overlay (0=transparent, 1=opaque).
show_colorbar (bool) – Whether to show a colorbar for the cost grid.
cost_legend_loc (str | None) – Location of the legend for the cost grid (e.g., ‘upper right’). If None, no legend is shown.
- show_geodesic_distance_matrix(dx=None, dy=None, *, show=True, save_path=None, cmap='viridis', normalize=False, colorbar=True, refresh=False)
Visualise the geodesic distance matrix for the discretised environment.
- cost_grid: MovementCostGrid | None = None
- geodesic_result: GeodesicDistanceResult | None = None
- class src.canns.task.navigation_base.GeodesicDistanceResult
- accessible_indices: numpy.ndarray
- cost_grid: MovementCostGrid
- distances: numpy.ndarray
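These three fields suggest how the distance between two cells can be read out. The sketch below assumes that accessible_indices is the sorted array of row-major flat indices of the free cells and that row/column k of distances corresponds to accessible_indices[k]. Neither assumption is stated in this reference, and distance_between_cells is a hypothetical helper.

```python
import numpy as np


def distance_between_cells(accessible_indices, distances, flat_a, flat_b):
    """Hypothetical helper: geodesic distance between two flat cell indices.

    Assumes accessible_indices is sorted and aligned with the rows/columns of
    distances. Returns None if either cell is not accessible.
    """
    # Binary search for each flat index among the accessible cells
    ks = np.searchsorted(accessible_indices, [flat_a, flat_b])
    for k, flat in zip(ks, (flat_a, flat_b)):
        if k >= len(accessible_indices) or accessible_indices[k] != flat:
            return None  # cell is blocked or outside the grid
    return float(distances[ks[0], ks[1]])
```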
- class src.canns.task.navigation_base.MovementCostGrid
- get_cell_index(pos)
Get the geodesic index of the grid cell containing the given position.
This method is JAX-compatible and can be used inside jitted functions.
- Parameters:
pos (collections.abc.Sequence[float]) – (x, y) coordinates of the position.
- Returns:
Index of the grid cell in the accessible_indices array, or -1 if the position is out of bounds or in an impassable cell.
- Return type:
int
Note
Returns -1 (instead of None) for JAX compatibility. The caller should check for negative values to detect invalid positions.
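Returning -1 rather than None matters because jitted JAX code cannot branch on traced values with Python if/else. One way to get a branch-free lookup is a precomputed table combined with where-selection, sketched below with NumPy. Every array operation it uses (searchsorted, clip, cumsum, where, integer indexing) also exists in jax.numpy, so swapping np for jnp should make it traceable. The names make_cell_lookup and lookup are hypothetical, and the table-based design is an assumption, not necessarily how get_cell_index is implemented.

```python
import numpy as np


def make_cell_lookup(x_edges, y_edges, costs):
    """Hypothetical sketch of a branch-free position -> cell index lookup."""
    free = (costs == 1).ravel()
    # Table over all cells: rank among free cells, or -1 for blocked cells
    table = np.where(free, np.cumsum(free) - 1, -1)
    n_rows, n_cols = costs.shape

    def lookup(pos):
        x, y = pos[0], pos[1]
        # Clip so the table index is always valid, even for outside points
        col = np.clip(np.searchsorted(x_edges, x, side="right") - 1, 0, n_cols - 1)
        row = np.clip(np.searchsorted(y_edges, y, side="right") - 1, 0, n_rows - 1)
        inside = (
            (x >= x_edges[0]) & (x <= x_edges[-1])
            & (y >= y_edges[0]) & (y <= y_edges[-1])
        )
        # Select instead of branching: -1 for out-of-bounds positions
        return np.where(inside, table[row * n_cols + col], -1)

    return lookup
```

Callers then test for a negative result, as the note above describes, instead of checking for None.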
- accessible_indices: numpy.ndarray | None = None
- property accessible_mask: numpy.ndarray
- costs: numpy.ndarray
- property x_centers: numpy.ndarray
- x_edges: numpy.ndarray
- property y_centers: numpy.ndarray
- y_edges: numpy.ndarray
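The properties listed above are presumably derived from the stored fields. The stand-in below assumes two things: cell centres are edge midpoints, and any cell whose cost is not INT32_MAX counts as accessible. CostGridSketch is a hypothetical class, not the library's MovementCostGrid.

```python
from dataclasses import dataclass

import numpy as np

INT32_MAX = np.iinfo(np.int32).max


@dataclass
class CostGridSketch:
    """Hypothetical stand-in mirroring the documented MovementCostGrid fields."""

    costs: np.ndarray    # shape (n_y, n_x); 1 = free, INT32_MAX = blocked
    x_edges: np.ndarray  # shape (n_x + 1,)
    y_edges: np.ndarray  # shape (n_y + 1,)

    @property
    def x_centers(self) -> np.ndarray:
        # Midpoint of each pair of consecutive edges
        return 0.5 * (self.x_edges[:-1] + self.x_edges[1:])

    @property
    def y_centers(self) -> np.ndarray:
        return 0.5 * (self.y_edges[:-1] + self.y_edges[1:])

    @property
    def accessible_mask(self) -> np.ndarray:
        return self.costs != INT32_MAX

    @property
    def accessible_indices(self) -> np.ndarray:
        # Row-major flat indices of accessible cells, in ascending order
        return np.flatnonzero(self.accessible_mask)
```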