deepxube.cli
from typing import List, Optional, Tuple, cast
import argparse
from argparse import ArgumentParser

from deepxube.train_cli import parser_train
from deepxube.base.domain import DomainParser, StateGoalVizable, StringToAct, State, Action, Goal
from deepxube.base.heuristic import HeurNNetParser
from deepxube.factories.domain_factory import get_all_domain_names, get_domain_parser
from deepxube.factories.nnet_input_factory import get_domain_nnet_input_keys
from deepxube.factories.heuristic_factory import get_all_heur_nnet_names, get_heur_nnet_parser
from deepxube.utils.command_line_utils import get_domain_from_arg

import matplotlib.pyplot as plt

import textwrap


def domain_info(args: argparse.Namespace) -> None:
    """Print every registered domain with its parser help text and neural-network input names.

    :param args: parsed command-line namespace; not read here, but required by the ``func`` dispatch in ``main``.
    """
    domain_names: List[str] = get_all_domain_names()
    for domain_name in domain_names:
        print(f"Domain: {domain_name}")
        parser: Optional[DomainParser] = get_domain_parser(domain_name)
        if parser is not None:
            print(textwrap.indent("Parser: " + parser.help(), '\t'))

        nnet_input_t_keys: List[Tuple[str, str]] = get_domain_nnet_input_keys(domain_name)
        if len(nnet_input_t_keys) > 0:
            # Only the second element of each key tuple is shown to the user.
            print(f"\tNNet Inputs: {', '.join(nnet_input_t_key[1] for nnet_input_t_key in nnet_input_t_keys)}")


def heur_info(args: argparse.Namespace) -> None:
    """Print every registered heuristic neural network with its parser help text, if it has a parser.

    :param args: parsed command-line namespace; not read here, but required by the ``func`` dispatch in ``main``.
    """
    heur_nnet_names: List[str] = get_all_heur_nnet_names()
    for heur_nnet_name in heur_nnet_names:
        print(f"Heur NNet: {heur_nnet_name}")
        parser: Optional[HeurNNetParser] = get_heur_nnet_parser(heur_nnet_name)
        if parser is not None:
            print(textwrap.indent("Parser: " + parser.help(), '\t'))


def viz(args: argparse.Namespace) -> None:
    """Show a generated start state and goal; for ``StringToAct`` domains, apply actions typed at the prompt.

    :param args: parsed arguments; ``args.domain`` is passed to ``get_domain_from_arg`` and ``args.steps``
        to ``get_start_goal_pairs`` to generate a single start/goal pair.
    :raises ValueError: if the selected domain does not implement ``StateGoalVizable``.
    """
    domain, _ = get_domain_from_arg(args.domain)
    # An explicit check (not `assert`) so it still runs under `python -O` and gives the user a clear message.
    if not isinstance(domain, StateGoalVizable):
        raise ValueError(f"Domain {args.domain!r} cannot be visualized: it does not implement StateGoalVizable")

    states, goals = domain.get_start_goal_pairs([args.steps])
    state: State = states[0]
    goal: Goal = goals[0]
    fig = plt.figure(figsize=(5, 5))
    domain.visualize_state_goal(state, goal, fig)
    print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")

    if isinstance(domain, StringToAct):
        plt.show(block=False)
        # input() below blocks the GUI event loop; flush pending events so the window actually renders.
        fig.canvas.flush_events()
        while True:
            act_str = input("Write action (make blank to quit): ")
            if len(act_str) == 0:
                break
            action: Optional[Action] = domain.string_to_action(act_str)
            if action is None:
                print(f"No action {act_str}")
                continue

            states_next, tcs = domain.next_state([state], [action])
            state = states_next[0]
            tc: float = tcs[0]
            fig.clear()
            cast(StateGoalVizable, domain).visualize_state_goal(state, goal, fig)
            print(f"Transition cost: {tc}")
            fig.canvas.draw()
            # draw() only renders; without processing events the update may not appear while input() waits.
            fig.canvas.flush_events()

            print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")
    else:
        plt.show(block=True)


def main() -> None:
    """Entry point of the ``deepxube`` command-line tool: build the parser and dispatch to the chosen subcommand."""
    parser = ArgumentParser(prog="deepxube", description="Solve pathfinding problems with deep reinforcement learning "
                                                         "and heuristic search.",
                            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    subparsers = parser.add_subparsers(help="")

    # train
    parser_tr: ArgumentParser = subparsers.add_parser('train', help="Train a heuristic function.",
                                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser_train(parser_tr)

    # domain info
    parser_domain_info: ArgumentParser = subparsers.add_parser('domain_info', help="Print information on domains that "
                                                                                   "deepxube has registered. "
                                                                                   "Put user-defined definitions of "
                                                                                   "domains in './domains/'")
    _parser_domain_info(parser_domain_info)

    # heuristic info
    parser_heur_info: ArgumentParser = subparsers.add_parser('heuristic_info', help="Print information on neural network "
                                                                                    "representations of heuristic functions "
                                                                                    "that deepxube has registered. "
                                                                                    "Put user-defined definitions of "
                                                                                    "heuristic neural networks in "
                                                                                    "'./heuristics/'")
    _parser_heur_info(parser_heur_info)

    # visualization
    parser_viz: ArgumentParser = subparsers.add_parser('viz', help="Visualize states/goals")
    _parse_viz_info(parser_viz)

    args = parser.parse_args()

    # Subparsers are optional by default, so `deepxube` with no subcommand leaves `func` unset.
    # Report a usage error instead of crashing with AttributeError.
    func = getattr(args, "func", None)
    if func is None:
        parser.error("a subcommand is required")
    func(args)


def _parser_domain_info(parser: ArgumentParser) -> None:
    """Register ``domain_info`` as the handler for the ``domain_info`` subcommand."""
    parser.set_defaults(func=domain_info)


def _parser_heur_info(parser: ArgumentParser) -> None:
    """Register ``heur_info`` as the handler for the ``heuristic_info`` subcommand."""
    parser.set_defaults(func=heur_info)


def _parse_viz_info(parser: ArgumentParser) -> None:
    """Add the ``viz`` subcommand's arguments and register ``viz`` as its handler."""
    parser.add_argument('--domain', type=str, required=True, help="Domain name and arguments.")
    parser.add_argument('--steps', type=int, default=0, help="Number of steps to take to generate problem instance.")
    parser.set_defaults(func=viz)
def
domain_info(args: argparse.Namespace) -> None:
def domain_info(args: argparse.Namespace) -> None:
    """List each registered domain together with its parser help and neural-network input names.

    :param args: parsed command-line namespace; unused here, but required by the subcommand ``func`` dispatch.
    """
    for name in get_all_domain_names():
        print(f"Domain: {name}")

        domain_parser: Optional[DomainParser] = get_domain_parser(name)
        if domain_parser is not None:
            print(textwrap.indent("Parser: " + domain_parser.help(), '\t'))

        # Only the second element of each input-key tuple is displayed.
        input_names: List[str] = [key[1] for key in get_domain_nnet_input_keys(name)]
        if input_names:
            print("\tNNet Inputs: " + ", ".join(input_names))
def
heur_info(args: argparse.Namespace) -> None:
def heur_info(args: argparse.Namespace) -> None:
    """List each registered heuristic neural network and, when available, its parser help text.

    :param args: parsed command-line namespace; unused here, but required by the subcommand ``func`` dispatch.
    """
    for nnet_name in get_all_heur_nnet_names():
        print(f"Heur NNet: {nnet_name}")
        nnet_parser: Optional[HeurNNetParser] = get_heur_nnet_parser(nnet_name)
        if nnet_parser is None:
            continue
        print(textwrap.indent("Parser: " + nnet_parser.help(), '\t'))
def
viz(args: argparse.Namespace) -> None:
def viz(args: argparse.Namespace) -> None:
    """Show a generated start state and goal; for ``StringToAct`` domains, apply actions typed at the prompt.

    :param args: parsed arguments; ``args.domain`` is passed to ``get_domain_from_arg`` and ``args.steps``
        to ``get_start_goal_pairs`` to generate a single start/goal pair.
    :raises ValueError: if the selected domain does not implement ``StateGoalVizable``.
    """
    domain, _ = get_domain_from_arg(args.domain)
    # An explicit check (not `assert`) so it still runs under `python -O` and gives the user a clear message.
    if not isinstance(domain, StateGoalVizable):
        raise ValueError(f"Domain {args.domain!r} cannot be visualized: it does not implement StateGoalVizable")

    states, goals = domain.get_start_goal_pairs([args.steps])
    state: State = states[0]
    goal: Goal = goals[0]
    fig = plt.figure(figsize=(5, 5))
    domain.visualize_state_goal(state, goal, fig)
    print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")

    if isinstance(domain, StringToAct):
        plt.show(block=False)
        # input() below blocks the GUI event loop; flush pending events so the window actually renders.
        fig.canvas.flush_events()
        while True:
            act_str = input("Write action (make blank to quit): ")
            if len(act_str) == 0:
                break
            action: Optional[Action] = domain.string_to_action(act_str)
            if action is None:
                print(f"No action {act_str}")
                continue

            states_next, tcs = domain.next_state([state], [action])
            state = states_next[0]
            tc: float = tcs[0]
            fig.clear()
            cast(StateGoalVizable, domain).visualize_state_goal(state, goal, fig)
            print(f"Transition cost: {tc}")
            fig.canvas.draw()
            # draw() only renders; without processing events the update may not appear while input() waits.
            fig.canvas.flush_events()

            print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")
    else:
        plt.show(block=True)
def
main() -> None:
def main() -> None:
    """Entry point of the ``deepxube`` command-line tool: build the parser and dispatch to the chosen subcommand."""
    parser = ArgumentParser(prog="deepxube", description="Solve pathfinding problems with deep reinforcement learning "
                                                         "and heuristic search.",
                            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    subparsers = parser.add_subparsers(help="")

    # train
    parser_tr: ArgumentParser = subparsers.add_parser('train', help="Train a heuristic function.",
                                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser_train(parser_tr)

    # domain info
    parser_domain_info: ArgumentParser = subparsers.add_parser('domain_info', help="Print information on domains that "
                                                                                   "deepxube has registered. "
                                                                                   "Put user-defined definitions of "
                                                                                   "domains in './domains/'")
    _parser_domain_info(parser_domain_info)

    # heuristic info
    parser_heur_info: ArgumentParser = subparsers.add_parser('heuristic_info', help="Print information on neural network "
                                                                                    "representations of heuristic functions "
                                                                                    "that deepxube has registered. "
                                                                                    "Put user-defined definitions of "
                                                                                    "heuristic neural networks in "
                                                                                    "'./heuristics/'")
    _parser_heur_info(parser_heur_info)

    # visualization
    parser_viz: ArgumentParser = subparsers.add_parser('viz', help="Visualize states/goals")
    _parse_viz_info(parser_viz)

    args = parser.parse_args()

    # Subparsers are optional by default, so `deepxube` with no subcommand leaves `func` unset.
    # Report a usage error instead of crashing with AttributeError.
    func = getattr(args, "func", None)
    if func is None:
        parser.error("a subcommand is required")
    func(args)