deepxube.cli

  1from typing import List, Optional, Tuple, cast
  2import argparse
  3from argparse import ArgumentParser
  4
  5from deepxube.train_cli import parser_train
  6from deepxube.base.domain import DomainParser, StateGoalVizable, StringToAct, State, Action, Goal
  7from deepxube.base.heuristic import HeurNNetParser
  8from deepxube.factories.domain_factory import get_all_domain_names, get_domain_parser
  9from deepxube.factories.nnet_input_factory import get_domain_nnet_input_keys
 10from deepxube.factories.heuristic_factory import get_all_heur_nnet_names, get_heur_nnet_parser
 11from deepxube.utils.command_line_utils import get_domain_from_arg
 12
 13import matplotlib.pyplot as plt
 14
 15import textwrap
 16
 17
 18def domain_info(args: argparse.Namespace) -> None:
 19    domain_names: List[str] = get_all_domain_names()
 20    for domain_name in domain_names:
 21        print(f"Domain: {domain_name}")
 22        parser: Optional[DomainParser] = get_domain_parser(domain_name)
 23        if parser is not None:
 24            print(textwrap.indent("Parser: " + parser.help(), '\t'))
 25
 26        nnet_input_t_keys: List[Tuple[str, str]] = get_domain_nnet_input_keys(domain_name)
 27        if len(nnet_input_t_keys) > 0:
 28            print(f"\tNNet Inputs: {', '.join(nnet_input_t_key[1] for nnet_input_t_key in nnet_input_t_keys)}")
 29
 30
 31def heur_info(args: argparse.Namespace) -> None:
 32    heur_nnet_names: List[str] = get_all_heur_nnet_names()
 33    for heur_nnet_name in heur_nnet_names:
 34        print(f"Heur NNet: {heur_nnet_name}")
 35        parser: Optional[HeurNNetParser] = get_heur_nnet_parser(heur_nnet_name)
 36        if parser is not None:
 37            print(textwrap.indent("Parser: " + parser.help(), '\t'))
 38
 39
 40def viz(args: argparse.Namespace) -> None:
 41    domain, domain_name = get_domain_from_arg(args.domain)
 42    assert isinstance(domain, StateGoalVizable)
 43    states, goals = domain.get_start_goal_pairs([args.steps])
 44    state: State = states[0]
 45    goal: Goal = goals[0]
 46    fig = plt.figure(figsize=(5, 5))
 47    domain.visualize_state_goal(state, goal, fig)
 48    print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")
 49
 50    if isinstance(domain, StringToAct):
 51        plt.show(block=False)
 52        while True:
 53            act_str = input("Write action (make blank to quit): ")
 54            if len(act_str) == 0:
 55                break
 56            action: Optional[Action] = domain.string_to_action(act_str)
 57            if action is None:
 58                print(f"No action {act_str}")
 59            else:
 60                states_next, tcs = domain.next_state([state], [action])
 61                state = states_next[0]
 62                tc: float = tcs[0]
 63                fig.clear()
 64                cast(StateGoalVizable, domain).visualize_state_goal(state, goal, fig)
 65                print(f"Transition cost: {tc}")
 66                fig.canvas.draw()
 67
 68                print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")
 69    else:
 70        plt.show(block=True)
 71
 72
 73def main() -> None:
 74    parser = ArgumentParser(prog="deepxube", description="Solve pathfinding problems with deep reinforcement learning "
 75                                                         "and heuristic search.",
 76                            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 77    subparsers = parser.add_subparsers(help="")
 78
 79    # train
 80    parser_tr: ArgumentParser = subparsers.add_parser('train', help="Train a heuristic function.",
 81                                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 82    parser_train(parser_tr)
 83
 84    # domain info
 85    parser_domain_info: ArgumentParser = subparsers.add_parser('domain_info', help="Print information on domains that "
 86                                                                                   "deepxube has registered. "
 87                                                                                   "Put user-defined definitions of "
 88                                                                                   "domains in './domains/'")
 89    _parser_domain_info(parser_domain_info)
 90
 91    # heuristic info
 92    parser_heur_info: ArgumentParser = subparsers.add_parser('heuristic_info', help="Print information on neural network "
 93                                                                                    "representations of heuristic functions "
 94                                                                                    "that deepxube has registered. "
 95                                                                                    "Put user-defined definitions of "
 96                                                                                    "heuristic neural networks in "
 97                                                                                    "'./heuristics/'")
 98    _parser_heur_info(parser_heur_info)
 99
100    # visualization
101    parser_viz: ArgumentParser = subparsers.add_parser('viz', help="Visualize states/goals")
102    _parse_viz_info(parser_viz)
103
104    args = parser.parse_args()
105
106    args.func(args)
107
108
def _parser_domain_info(parser: ArgumentParser) -> None:
    """Configure the ``domain_info`` subcommand parser.

    The subcommand takes no arguments of its own; this only registers
    ``domain_info`` as the handler stored in the ``func`` attribute.

    :param parser: Parser of the ``domain_info`` subcommand.
    """
    parser.set_defaults(func=domain_info)
111
112
def _parser_heur_info(parser: ArgumentParser) -> None:
    """Configure the ``heuristic_info`` subcommand parser.

    The subcommand takes no arguments of its own; this only registers
    ``heur_info`` as the handler stored in the ``func`` attribute.

    :param parser: Parser of the ``heuristic_info`` subcommand.
    """
    parser.set_defaults(func=heur_info)
115
116
def _parse_viz_info(parser: ArgumentParser) -> None:
    """Configure the ``viz`` subcommand parser.

    Adds the required ``--domain`` option and the optional ``--steps`` option
    (integer, default 0), and registers ``viz`` as the handler stored in the
    ``func`` attribute.

    :param parser: Parser of the ``viz`` subcommand.
    """
    parser.add_argument('--domain', type=str, required=True, help="Domain name and arguments.")
    # Help text typo fixed ("instnace" -> "instance").
    parser.add_argument('--steps', type=int, default=0, help="Number of steps to take to generate problem instance.")
    parser.set_defaults(func=viz)
def domain_info(args: argparse.Namespace) -> None:
    """Print a summary of every domain registered with deepxube.

    For each registered domain name this prints the name, followed by the
    tab-indented help text of its parser (when a parser is registered) and the
    neural network inputs registered for that domain (when there are any).

    :param args: Parsed command-line arguments. Not read by this command.
    """
    for name in get_all_domain_names():
        print(f"Domain: {name}")

        domain_parser: Optional[DomainParser] = get_domain_parser(name)
        if domain_parser is not None:
            parser_text: str = "Parser: " + domain_parser.help()
            print(textwrap.indent(parser_text, '\t'))

        input_keys: List[Tuple[str, str]] = get_domain_nnet_input_keys(name)
        if len(input_keys) == 0:
            continue
        # Only the second element of each key tuple is displayed.
        shown_names: str = ', '.join(key[1] for key in input_keys)
        print(f"\tNNet Inputs: {shown_names}")
def heur_info(args: argparse.Namespace) -> None:
    """Print a summary of every heuristic neural network registered with deepxube.

    For each registered name this prints the name and, when a parser is
    registered for it, the tab-indented help text of that parser.

    :param args: Parsed command-line arguments. Not read by this command.
    """
    for name in get_all_heur_nnet_names():
        print(f"Heur NNet: {name}")

        heur_parser: Optional[HeurNNetParser] = get_heur_nnet_parser(name)
        if heur_parser is None:
            continue
        parser_text: str = "Parser: " + heur_parser.help()
        print(textwrap.indent(parser_text, '\t'))
def viz(args: argparse.Namespace) -> None:
    """Visualize a start state/goal pair of a domain and optionally step through actions.

    One start/goal pair is obtained from ``domain.get_start_goal_pairs([args.steps])``
    and drawn on a 5x5 inch matplotlib figure. If the domain also implements
    ``StringToAct``, the figure is shown without blocking and the user is
    prompted for actions in a loop; each valid action is applied to the current
    state, its transition cost is printed and the figure is redrawn. Otherwise
    the figure is shown with a blocking ``plt.show``.

    :param args: Parsed command-line arguments; ``args.domain`` and
        ``args.steps`` are read.
    :raises TypeError: If the selected domain does not implement
        ``StateGoalVizable``.
    """
    domain, _ = get_domain_from_arg(args.domain)
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``
    # and the domain is chosen by the user on the command line.
    if not isinstance(domain, StateGoalVizable):
        raise TypeError(f"Domain '{args.domain}' cannot be visualized: it does not implement "
                        f"StateGoalVizable.")

    states, goals = domain.get_start_goal_pairs([args.steps])
    state: State = states[0]
    goal: Goal = goals[0]
    fig = plt.figure(figsize=(5, 5))
    domain.visualize_state_goal(state, goal, fig)
    print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")

    if isinstance(domain, StringToAct):
        # Non-blocking so that the prompt below stays usable while the figure is open.
        plt.show(block=False)
        while True:
            try:
                act_str = input("Write action (make blank to quit): ")
            except EOFError:
                # End of input (e.g. Ctrl-D or piped stdin) is treated like a blank line.
                print()
                break
            if len(act_str) == 0:
                break
            action: Optional[Action] = domain.string_to_action(act_str)
            if action is None:
                print(f"No action {act_str}")
            else:
                states_next, tcs = domain.next_state([state], [action])
                state = states_next[0]
                tc: float = tcs[0]
                fig.clear()
                # cast: the isinstance(StringToAct) narrowing above hides the vizable interface
                # from the type checker.
                cast(StateGoalVizable, domain).visualize_state_goal(state, goal, fig)
                print(f"Transition cost: {tc}")
                fig.canvas.draw()
                # Let the GUI process the pending repaint before blocking on input() again.
                fig.canvas.flush_events()

                print(f"Goal Reached: {domain.is_solved([state], [goal])[0]}")
    else:
        plt.show(block=True)
def main() -> None:
    """Entry point of the ``deepxube`` command-line interface.

    Builds the top-level argument parser with the subcommands ``train``,
    ``domain_info``, ``heuristic_info`` and ``viz``, parses the command line and
    calls the handler stored in the ``func`` attribute of the parsed arguments.

    The ``domain_info``, ``heuristic_info`` and ``viz`` handlers are registered
    via ``set_defaults(func=...)`` by the helper functions called below; the
    ``train`` subcommand is configured by ``parser_train`` (presumably it sets
    ``func`` as well -- confirm in ``deepxube.train_cli``).

    If no handler was selected (e.g. the program is run without a subcommand),
    a usage error is reported through ``parser.error`` instead of failing with
    an ``AttributeError`` on ``args.func``.
    """
    parser = ArgumentParser(prog="deepxube", description="Solve pathfinding problems with deep reinforcement learning "
                                                         "and heuristic search.",
                            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    subparsers = parser.add_subparsers(help="")

    # train
    parser_tr: ArgumentParser = subparsers.add_parser('train', help="Train a heuristic function.",
                                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser_train(parser_tr)

    # domain info
    parser_domain_info: ArgumentParser = subparsers.add_parser('domain_info', help="Print information on domains that "
                                                                                   "deepxube has registered. "
                                                                                   "Put user-defined definitions of "
                                                                                   "domains in './domains/'")
    _parser_domain_info(parser_domain_info)

    # heuristic info
    parser_heur_info: ArgumentParser = subparsers.add_parser('heuristic_info', help="Print information on neural network "
                                                                                    "representations of heuristic functions "
                                                                                    "that deepxube has registered. "
                                                                                    "Put user-defined definitions of "
                                                                                    "heuristic neural networks in "
                                                                                    "'./heuristics/'")
    _parser_heur_info(parser_heur_info)

    # visualization
    parser_viz: ArgumentParser = subparsers.add_parser('viz', help="Visualize states/goals")
    _parse_viz_info(parser_viz)

    args = parser.parse_args()

    # Subparsers are optional in argparse by default, so ``func`` is missing when no
    # subcommand was given. parser.error prints the usage and exits with status 2.
    func = getattr(args, "func", None)
    if func is None:
        parser.error("no command given; run with --help to list the available commands")
    func(args)