Creating a Custom Domain and Heuristic Function¶
We will create a simple grid domain where the agent can move up, down, left, or right along a two-dimensional grid to reach a goal square. We will create neural network inputs for a neural network that DeepXube provides as well as a neural network input for our own custom neural network.
In the directory in which you run deepxube, create the file domains/grid_tutorial.py (creating the domains/ directory first if it does not already exist).
DeepXube automatically looks in the domains/ folder to see what is registered.
Implementation¶
The entire domain file is here. This includes the states, actions, goals, domain, neural network inputs, custom neural network, and parsers. This file will be explained part-by-part.
from typing import List, Tuple, Optional, Type
import numpy as np
from torch import nn, Tensor
from deepxube.base.factory import DelimParser
from deepxube.base.domain import State, Action, Goal, ActsEnumFixed, StartGoalWalkable, StateGoalVizable, StringToAct
from deepxube.base.nnet_input import StateGoalIn, StateGoalActFixIn, StateGoalActIn, FlatIn
from deepxube.base.heuristic import HeurNNet
from deepxube.factories.heuristic_factory import heuristic_factory
from deepxube.factories.domain_factory import domain_factory
from deepxube.factories.nnet_input_factory import register_nnet_input
from deepxube.nnet.pytorch_models import Conv2dModel, FullyConnectedModel
from numpy.typing import NDArray
import random
from matplotlib.figure import Figure
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
# state, action, goal
class GridState(State):
    """Position of the robot on the grid.

    States are hashable and comparable by value so they can be used as
    dictionary keys and recognised when they are encountered again.
    """

    def __init__(self, robot_x: int, robot_y: int):
        self.robot_x: int = robot_x
        self.robot_y: int = robot_y

    def __hash__(self) -> int:
        # Hash the coordinate pair rather than the sum: with the sum, every
        # cell on the same anti-diagonal (e.g. (1, 2) and (2, 1)) collides.
        return hash((self.robot_x, self.robot_y))

    def __eq__(self, other: object) -> bool:
        if isinstance(other, GridState):
            return (self.robot_x == other.robot_x) and (self.robot_y == other.robot_y)
        return NotImplemented
class GridGoal(Goal):
    """Target cell, stored as the coordinates the robot has to reach."""

    def __init__(self, robot_x: int, robot_y: int):
        # Same coordinate convention as GridState.
        self.robot_y: int = robot_y
        self.robot_x: int = robot_x
class GridAction(Action):
    """A single move, identified by an integer index.

    The index doubles as the hash and selects the printed name
    (0: UP, 1: DOWN, 2: LEFT, 3: RIGHT).
    """

    def __init__(self, action: int):
        self.action = action

    def __hash__(self) -> int:
        # The index is already a unique integer per action.
        return self.action

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, GridAction):
            return NotImplemented
        return self.action == other.action

    def __repr__(self) -> str:
        names = ["UP", "DOWN", "LEFT", "RIGHT"]
        return names[self.action]
# domain definition
@domain_factory.register_class("grid_tut")
class Grid(ActsEnumFixed[GridState, GridAction, GridGoal], StartGoalWalkable[GridState, GridAction, GridGoal],
           StateGoalVizable[GridState, GridAction, GridGoal], StringToAct[GridState, GridAction, GridGoal]):
    """Square grid in which a robot moves one cell at a time towards a goal cell.

    Coordinates are clamped to ``[0, dim - 1]``, so moving into a wall leaves
    the robot where it is. Every move has a cost of 1.0.
    """

    def __init__(self, dim: int = 7):
        """Create a ``dim`` x ``dim`` grid with the four fixed actions 0..3."""
        super().__init__()
        self.dim: int = dim
        self.actions_fixed: List[GridAction] = [GridAction(x) for x in [0, 1, 2, 3]]

    def is_solved(self, states: List[GridState], goals: List[GridGoal]) -> List[bool]:
        """Return, per state/goal pair, whether the robot is on the goal cell."""
        return [(state.robot_x == goal.robot_x) and (state.robot_y == goal.robot_y)
                for state, goal in zip(states, goals)]

    def sample_start_states(self, num_states: int) -> List[GridState]:
        """Return ``num_states`` states with uniformly random coordinates."""
        # random.randint is inclusive on both ends, hence dim - 1.
        return [GridState(random.randint(0, self.dim - 1), random.randint(0, self.dim - 1))
                for _ in range(num_states)]

    def next_state(self, states: List[GridState], actions: List[GridAction]) -> Tuple[List[GridState], List[float]]:
        """Apply one action to each state.

        ``robot_x`` is used as the row index and ``robot_y`` as the column
        index by ``visualize_state_goal``, which draws row 0 at the top, so
        "up" decreases ``robot_x`` and "right" increases ``robot_y``. The
        action indices agree with ``GridAction.__repr__`` and
        ``string_to_action``.

        Returns:
            The next states and the transition cost (always 1.0) for each.

        Raises:
            ValueError: if an action index is not one of 0, 1, 2, 3.
        """
        states_next: List[GridState] = []
        for state, action in zip(states, actions):
            if action.action == 0:  # up
                states_next.append(GridState(max(state.robot_x - 1, 0), state.robot_y))
            elif action.action == 1:  # down
                states_next.append(GridState(min(state.robot_x + 1, self.dim - 1), state.robot_y))
            elif action.action == 2:  # left
                states_next.append(GridState(state.robot_x, max(state.robot_y - 1, 0)))
            elif action.action == 3:  # right
                states_next.append(GridState(state.robot_x, min(state.robot_y + 1, self.dim - 1)))
            else:
                # Skipping silently would leave the output shorter than, and
                # misaligned with, the input lists.
                raise ValueError(f"Unknown action index: {action.action}")
        return states_next, [1.0] * len(states_next)

    def sample_goal_from_state(self, states_start: Optional[List[GridState]], states_goal: List[GridState]) -> List[GridGoal]:
        """Turn each goal state into a goal at the same coordinates.

        ``states_start`` is not used.
        """
        return [GridGoal(state_goal.robot_x, state_goal.robot_y) for state_goal in states_goal]

    def visualize_state_goal(self, state: GridState, goal: GridGoal, fig: Figure) -> None:
        """Draw the grid on ``fig`` with the goal cell set to 2 and the robot cell set to 1.

        The robot value is written last, so it overwrites the goal value when
        both are on the same cell.
        """
        # Create the axes on the supplied figure itself. Creating them with
        # pyplot and attaching them afterwards only works when ``fig`` happens
        # to be pyplot's current figure.
        ax = fig.add_subplot(1, 1, 1)
        grid: NDArray = np.zeros((self.dim, self.dim))
        grid[goal.robot_x, goal.robot_y] = 2
        grid[state.robot_x, state.robot_y] = 1
        ax.imshow(grid, cmap=ListedColormap(["white", "black", "green"]), origin="upper")

    def string_to_action(self, act_str: str) -> Optional[GridAction]:
        """Map the keys w, s, a, d to actions 0, 1, 2, 3; return None for anything else."""
        keys: List[str] = ["w", "s", "a", "d"]
        if act_str in keys:
            return GridAction(keys.index(act_str))
        return None

    def string_to_action_help(self) -> str:
        """Return the help text describing the accepted action strings."""
        return "w, s, a, or d for up, down, left, and right, respectively."

    def get_actions_fixed(self) -> List[GridAction]:
        """Return a copy of the action list so callers cannot mutate the original."""
        return self.actions_fixed.copy()

    def __repr__(self) -> str:
        return f"Grid(dim={self.dim})"
# domain parser definition
@domain_factory.register_parser("grid_tut")
class GridParser(DelimParser):
    """Parser for the ``grid_tut`` domain arguments, which are separated by ``_``."""

    def __init__(self) -> None:
        super().__init__()
        # Registers the integer argument "dim" under the short name "d".
        self.add_argument("d", "dim", int, "dimensionality of grid")

    @property
    def delim(self) -> str:
        """Delimiter string used by this parser."""
        return "_"
# gridflatin definition
@register_nnet_input("grid_tut", "grid_flat_in")
class GridFlatIn(StateGoalIn[Grid, GridState, GridGoal], FlatIn[Grid]):
    """Flat state/goal input: four integer coordinates per example."""

    def get_input_info(self) -> Tuple[List[int], List[int]]:
        # Presumably (features per input, number of values each feature can take);
        # confirm against the FlatIn contract.
        num_features: List[int] = [4]
        depths: List[int] = [self.domain.dim]
        return num_features, depths

    def to_np(self, states: List[GridState], goals: List[GridGoal]) -> List[NDArray]:
        """Return one array whose columns are state x, state y, goal x, goal y."""
        columns: List[NDArray] = [
            np.stack([state.robot_x for state in states]),
            np.stack([state.robot_y for state in states]),
            np.stack([goal.robot_x for goal in goals]),
            np.stack([goal.robot_y for goal in goals]),
        ]
        coords: NDArray = np.stack(columns, axis=1)
        return [coords]
# gridflatinqfix definition
@register_nnet_input("grid_tut", "grid_flat_in_qfix")
class GridFlatInQFix(StateGoalActFixIn[Grid, GridState, GridGoal, GridAction], FlatIn[Grid]):
    """Flat state/goal input for a Q-network with a fixed action output."""

    def get_input_info(self) -> Tuple[List[int], List[int]]:
        # Presumably (features per input, number of values each feature can take);
        # confirm against the FlatIn contract.
        num_features: List[int] = [4]
        depths: List[int] = [self.domain.dim]
        return num_features, depths

    def to_np(self, states: List[GridState], goals: List[GridGoal], actions_l: List[List[GridAction]]) -> List[NDArray]:
        """Return the coordinate array followed by an array of action indices.

        The coordinate columns are state x, state y, goal x, goal y. The second
        array holds one row of action indices per entry of ``actions_l``.
        """
        act_indices: NDArray = np.array([[act.action for act in acts] for acts in actions_l])
        columns: List[NDArray] = [
            np.stack([state.robot_x for state in states]),
            np.stack([state.robot_y for state in states]),
            np.stack([goal.robot_x for goal in goals]),
            np.stack([goal.robot_y for goal in goals]),
        ]
        coords: NDArray = np.stack(columns, axis=1)
        return [coords, act_indices]
# gridflatinactin definition
@register_nnet_input("grid_tut", "grid_flat_in_actin")
class GridFlatInActIn(StateGoalActIn[Grid, GridState, GridGoal, GridAction], FlatIn[Grid]):
    """Flat state/goal input for a Q-network that takes the action as an input."""

    def get_input_info(self) -> Tuple[List[int], List[int]]:
        # Two inputs: four coordinates and one action index. The second list is
        # presumably the number of values each can take; confirm against FlatIn.
        num_features: List[int] = [4, 1]
        depths: List[int] = [self.domain.dim, self.domain.get_num_acts()]
        return num_features, depths

    def to_np(self, states: List[GridState], goals: List[GridGoal], actions: List[GridAction]) -> List[NDArray]:
        """Return the coordinate array followed by a one-column array of action indices.

        The coordinate columns are state x, state y, goal x, goal y.
        """
        act_column: NDArray = np.expand_dims(np.array([act.action for act in actions]), 1)
        columns: List[NDArray] = [
            np.stack([state.robot_x for state in states]),
            np.stack([state.robot_y for state in states]),
            np.stack([goal.robot_x for goal in goals]),
            np.stack([goal.robot_y for goal in goals]),
        ]
        coords: NDArray = np.stack(columns, axis=1)
        return [coords, act_column]
# grid nnet definition
@register_nnet_input("grid_tut", "grid_nnet_input")
class GridNNetInput(StateGoalIn[Grid, GridState, GridGoal]):
    """Image-like input: one plane marking the robot cell and one marking the goal cell."""

    def get_input_info(self) -> int:
        """Return the side length of the grid."""
        return self.domain.dim

    def to_np(self, states: List[GridState], goals: List[GridGoal]) -> List[NDArray]:
        """Return a single array of shape ``(len(states), 2, dim, dim)``.

        Plane 0 has a 1 at the robot position and plane 1 has a 1 at the goal
        position; every other entry is 0.
        """
        side: int = self.domain.dim
        planes: NDArray = np.zeros((len(states), 2, side, side))
        for row, (state, goal) in enumerate(zip(states, goals)):
            planes[row, 0, state.robot_x, state.robot_y] = 1
            planes[row, 1, goal.robot_x, goal.robot_y] = 1
        return [planes]
@heuristic_factory.register_class("gridnet_tut")
class GridNet(HeurNNet[GridNNetInput]):
    """Custom heuristic network: convolutional layers, then fully connected layers."""

    @staticmethod
    def nnet_input_type() -> Type[GridNNetInput]:
        """Return the neural network input class this network is built for."""
        return GridNNetInput

    def __init__(self, nnet_input: GridNNetInput, out_dim: int, q_fix: bool, chan_size: int = 8, fc_size: int = 100):
        super().__init__(nnet_input, out_dim, q_fix)
        # one hots
        self.one_hots: nn.ModuleList = nn.ModuleList()

        side: int = self.nnet_input.get_input_info()
        # The flattened size assumes the convolutional stack keeps the
        # side x side spatial extent; confirm against Conv2dModel.
        layers: List[nn.Module] = [
            Conv2dModel(2, [chan_size, chan_size], [3, 3], [1, 1], ["RELU", "RELU"], batch_norms=[True, True]),
            nn.Flatten(),
            FullyConnectedModel(side * side * chan_size, [fc_size], ["RELU"], batch_norms=[True]),
            nn.Linear(fc_size, self.out_dim),
        ]
        self.heur: nn.Module = nn.Sequential(*layers)

    def _forward(self, inputs: List[Tensor]) -> Tensor:
        """Run the network on the first (and only used) input tensor."""
        return self.heur(inputs[0])
@heuristic_factory.register_parser("gridnet_tut")
class GridNetParser(DelimParser):
    """Parser for the ``gridnet_tut`` network arguments, which are separated by ``_``."""

    def __init__(self) -> None:
        super().__init__()
        # Integer arguments "chan_size" and "fc_size", under the short names "ch" and "fc".
        self.add_argument("ch", "chan_size", int, "number of channels")
        self.add_argument("fc", "fc_size", int, "size of fully connected layer")

    @property
    def delim(self) -> str:
        """Delimiter string used by this parser."""
        return "_"
State, Action, Goal¶
To facilitate using states with Python dictionary objects and re-identifying states during search, all State objects must implement
__hash__ and __eq__. This must also be done for Action objects.
Tip
Implementing __repr__ for action objects can be convenient since actions are printed to the screen when interacting with problem
instances with deepxube viz.
class GridState(State):
    """Position of the robot on the grid.

    States are hashable and comparable by value so they can be used as
    dictionary keys and recognised when they are encountered again.
    """

    def __init__(self, robot_x: int, robot_y: int):
        self.robot_x: int = robot_x
        self.robot_y: int = robot_y

    def __hash__(self) -> int:
        # Hash the coordinate pair rather than the sum: with the sum, every
        # cell on the same anti-diagonal (e.g. (1, 2) and (2, 1)) collides.
        return hash((self.robot_x, self.robot_y))

    def __eq__(self, other: object) -> bool:
        if isinstance(other, GridState):
            return (self.robot_x == other.robot_x) and (self.robot_y == other.robot_y)
        return NotImplemented
class GridGoal(Goal):
    """Target cell, stored as the coordinates the robot has to reach."""

    def __init__(self, robot_x: int, robot_y: int):
        # Same coordinate convention as GridState.
        self.robot_y: int = robot_y
        self.robot_x: int = robot_x
class GridAction(Action):
    """A single move, identified by an integer index.

    The index doubles as the hash and selects the printed name
    (0: UP, 1: DOWN, 2: LEFT, 3: RIGHT).
    """

    def __init__(self, action: int):
        self.action = action

    def __hash__(self) -> int:
        # The index is already a unique integer per action.
        return self.action

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, GridAction):
            return NotImplemented
        return self.action == other.action

    def __repr__(self) -> str:
        names = ["UP", "DOWN", "LEFT", "RIGHT"]
        return names[self.action]
# domain definition
Domain¶
@domain_factory.register_class("grid_tut")
class Grid(ActsEnumFixed[GridState, GridAction, GridGoal], StartGoalWalkable[GridState, GridAction, GridGoal],
           StateGoalVizable[GridState, GridAction, GridGoal], StringToAct[GridState, GridAction, GridGoal]):
    """Square grid in which a robot moves one cell at a time towards a goal cell.

    Coordinates are clamped to ``[0, dim - 1]``, so moving into a wall leaves
    the robot where it is. Every move has a cost of 1.0.
    """

    def __init__(self, dim: int = 7):
        """Create a ``dim`` x ``dim`` grid with the four fixed actions 0..3."""
        super().__init__()
        self.dim: int = dim
        self.actions_fixed: List[GridAction] = [GridAction(x) for x in [0, 1, 2, 3]]

    def is_solved(self, states: List[GridState], goals: List[GridGoal]) -> List[bool]:
        """Return, per state/goal pair, whether the robot is on the goal cell."""
        return [(state.robot_x == goal.robot_x) and (state.robot_y == goal.robot_y)
                for state, goal in zip(states, goals)]

    def sample_start_states(self, num_states: int) -> List[GridState]:
        """Return ``num_states`` states with uniformly random coordinates."""
        # random.randint is inclusive on both ends, hence dim - 1.
        return [GridState(random.randint(0, self.dim - 1), random.randint(0, self.dim - 1))
                for _ in range(num_states)]

    def next_state(self, states: List[GridState], actions: List[GridAction]) -> Tuple[List[GridState], List[float]]:
        """Apply one action to each state.

        ``robot_x`` is used as the row index and ``robot_y`` as the column
        index by ``visualize_state_goal``, which draws row 0 at the top, so
        "up" decreases ``robot_x`` and "right" increases ``robot_y``. The
        action indices agree with ``GridAction.__repr__`` and
        ``string_to_action``.

        Returns:
            The next states and the transition cost (always 1.0) for each.

        Raises:
            ValueError: if an action index is not one of 0, 1, 2, 3.
        """
        states_next: List[GridState] = []
        for state, action in zip(states, actions):
            if action.action == 0:  # up
                states_next.append(GridState(max(state.robot_x - 1, 0), state.robot_y))
            elif action.action == 1:  # down
                states_next.append(GridState(min(state.robot_x + 1, self.dim - 1), state.robot_y))
            elif action.action == 2:  # left
                states_next.append(GridState(state.robot_x, max(state.robot_y - 1, 0)))
            elif action.action == 3:  # right
                states_next.append(GridState(state.robot_x, min(state.robot_y + 1, self.dim - 1)))
            else:
                # Skipping silently would leave the output shorter than, and
                # misaligned with, the input lists.
                raise ValueError(f"Unknown action index: {action.action}")
        return states_next, [1.0] * len(states_next)

    def sample_goal_from_state(self, states_start: Optional[List[GridState]], states_goal: List[GridState]) -> List[GridGoal]:
        """Turn each goal state into a goal at the same coordinates.

        ``states_start`` is not used.
        """
        return [GridGoal(state_goal.robot_x, state_goal.robot_y) for state_goal in states_goal]

    def visualize_state_goal(self, state: GridState, goal: GridGoal, fig: Figure) -> None:
        """Draw the grid on ``fig`` with the goal cell set to 2 and the robot cell set to 1.

        The robot value is written last, so it overwrites the goal value when
        both are on the same cell.
        """
        # Create the axes on the supplied figure itself. Creating them with
        # pyplot and attaching them afterwards only works when ``fig`` happens
        # to be pyplot's current figure.
        ax = fig.add_subplot(1, 1, 1)
        grid: NDArray = np.zeros((self.dim, self.dim))
        grid[goal.robot_x, goal.robot_y] = 2
        grid[state.robot_x, state.robot_y] = 1
        ax.imshow(grid, cmap=ListedColormap(["white", "black", "green"]), origin="upper")

    def string_to_action(self, act_str: str) -> Optional[GridAction]:
        """Map the keys w, s, a, d to actions 0, 1, 2, 3; return None for anything else."""
        keys: List[str] = ["w", "s", "a", "d"]
        if act_str in keys:
            return GridAction(keys.index(act_str))
        return None

    def string_to_action_help(self) -> str:
        """Return the help text describing the accepted action strings."""
        return "w, s, a, or d for up, down, left, and right, respectively."

    def get_actions_fixed(self) -> List[GridAction]:
        """Return a copy of the action list so callers cannot mutate the original."""
        return self.actions_fixed.copy()

    def __repr__(self) -> str:
        return f"Grid(dim={self.dim})"
Domain Parser¶
@domain_factory.register_parser("grid_tut")
class GridParser(DelimParser):
    """Parser for the ``grid_tut`` domain arguments, which are separated by ``_``."""

    def __init__(self) -> None:
        super().__init__()
        # Registers the integer argument "dim" under the short name "d".
        self.add_argument("d", "dim", int, "dimensionality of grid")

    @property
    def delim(self) -> str:
        """Delimiter string used by this parser."""
        return "_"
Neural Network Inputs¶
Flat Input¶
@register_nnet_input("grid_tut", "grid_flat_in")
class GridFlatIn(StateGoalIn[Grid, GridState, GridGoal], FlatIn[Grid]):
    """Flat state/goal input: four integer coordinates per example."""

    def get_input_info(self) -> Tuple[List[int], List[int]]:
        # Presumably (features per input, number of values each feature can take);
        # confirm against the FlatIn contract.
        num_features: List[int] = [4]
        depths: List[int] = [self.domain.dim]
        return num_features, depths

    def to_np(self, states: List[GridState], goals: List[GridGoal]) -> List[NDArray]:
        """Return one array whose columns are state x, state y, goal x, goal y."""
        columns: List[NDArray] = [
            np.stack([state.robot_x for state in states]),
            np.stack([state.robot_y for state in states]),
            np.stack([goal.robot_x for goal in goals]),
            np.stack([goal.robot_y for goal in goals]),
        ]
        coords: NDArray = np.stack(columns, axis=1)
        return [coords]
Flat Input for a Q-Network with a Fixed Action Output¶
@register_nnet_input("grid_tut", "grid_flat_in_qfix")
class GridFlatInQFix(StateGoalActFixIn[Grid, GridState, GridGoal, GridAction], FlatIn[Grid]):
    """Flat state/goal input for a Q-network with a fixed action output."""

    def get_input_info(self) -> Tuple[List[int], List[int]]:
        # Presumably (features per input, number of values each feature can take);
        # confirm against the FlatIn contract.
        num_features: List[int] = [4]
        depths: List[int] = [self.domain.dim]
        return num_features, depths

    def to_np(self, states: List[GridState], goals: List[GridGoal], actions_l: List[List[GridAction]]) -> List[NDArray]:
        """Return the coordinate array followed by an array of action indices.

        The coordinate columns are state x, state y, goal x, goal y. The second
        array holds one row of action indices per entry of ``actions_l``.
        """
        act_indices: NDArray = np.array([[act.action for act in acts] for acts in actions_l])
        columns: List[NDArray] = [
            np.stack([state.robot_x for state in states]),
            np.stack([state.robot_y for state in states]),
            np.stack([goal.robot_x for goal in goals]),
            np.stack([goal.robot_y for goal in goals]),
        ]
        coords: NDArray = np.stack(columns, axis=1)
        return [coords, act_indices]
Flat Input for a Q-Network with the Action as an Input¶
@register_nnet_input("grid_tut", "grid_flat_in_actin")
class GridFlatInActIn(StateGoalActIn[Grid, GridState, GridGoal, GridAction], FlatIn[Grid]):
    """Flat state/goal input for a Q-network that takes the action as an input."""

    def get_input_info(self) -> Tuple[List[int], List[int]]:
        # Two inputs: four coordinates and one action index. The second list is
        # presumably the number of values each can take; confirm against FlatIn.
        num_features: List[int] = [4, 1]
        depths: List[int] = [self.domain.dim, self.domain.get_num_acts()]
        return num_features, depths

    def to_np(self, states: List[GridState], goals: List[GridGoal], actions: List[GridAction]) -> List[NDArray]:
        """Return the coordinate array followed by a one-column array of action indices.

        The coordinate columns are state x, state y, goal x, goal y.
        """
        act_column: NDArray = np.expand_dims(np.array([act.action for act in actions]), 1)
        columns: List[NDArray] = [
            np.stack([state.robot_x for state in states]),
            np.stack([state.robot_y for state in states]),
            np.stack([goal.robot_x for goal in goals]),
            np.stack([goal.robot_y for goal in goals]),
        ]
        coords: NDArray = np.stack(columns, axis=1)
        return [coords, act_column]
Custom Neural Network¶
@register_nnet_input("grid_tut", "grid_nnet_input")
class GridNNetInput(StateGoalIn[Grid, GridState, GridGoal]):
    """Image-like input: one plane marking the robot cell and one marking the goal cell."""

    def get_input_info(self) -> int:
        """Return the side length of the grid."""
        return self.domain.dim

    def to_np(self, states: List[GridState], goals: List[GridGoal]) -> List[NDArray]:
        """Return a single array of shape ``(len(states), 2, dim, dim)``.

        Plane 0 has a 1 at the robot position and plane 1 has a 1 at the goal
        position; every other entry is 0.
        """
        side: int = self.domain.dim
        planes: NDArray = np.zeros((len(states), 2, side, side))
        for row, (state, goal) in enumerate(zip(states, goals)):
            planes[row, 0, state.robot_x, state.robot_y] = 1
            planes[row, 1, goal.robot_x, goal.robot_y] = 1
        return [planes]
@heuristic_factory.register_class("gridnet_tut")
class GridNet(HeurNNet[GridNNetInput]):
    """Custom heuristic network: convolutional layers, then fully connected layers."""

    @staticmethod
    def nnet_input_type() -> Type[GridNNetInput]:
        """Return the neural network input class this network is built for."""
        return GridNNetInput

    def __init__(self, nnet_input: GridNNetInput, out_dim: int, q_fix: bool, chan_size: int = 8, fc_size: int = 100):
        super().__init__(nnet_input, out_dim, q_fix)
        # one hots
        self.one_hots: nn.ModuleList = nn.ModuleList()

        side: int = self.nnet_input.get_input_info()
        # The flattened size assumes the convolutional stack keeps the
        # side x side spatial extent; confirm against Conv2dModel.
        layers: List[nn.Module] = [
            Conv2dModel(2, [chan_size, chan_size], [3, 3], [1, 1], ["RELU", "RELU"], batch_norms=[True, True]),
            nn.Flatten(),
            FullyConnectedModel(side * side * chan_size, [fc_size], ["RELU"], batch_norms=[True]),
            nn.Linear(fc_size, self.out_dim),
        ]
        self.heur: nn.Module = nn.Sequential(*layers)

    def _forward(self, inputs: List[Tensor]) -> Tensor:
        """Run the network on the first (and only used) input tensor."""
        return self.heur(inputs[0])
@heuristic_factory.register_parser("gridnet_tut")
class GridNetParser(DelimParser):
    """Parser for the ``gridnet_tut`` network arguments, which are separated by ``_``."""

    def __init__(self) -> None:
        super().__init__()
        # Integer arguments "chan_size" and "fc_size", under the short names "ch" and "fc".
        self.add_argument("ch", "chan_size", int, "number of channels")
        self.add_argument("fc", "fc_size", int, "size of fully connected layer")

    @property
    def delim(self) -> str:
        """Delimiter string used by this parser."""
        return "_"