deepxube.base.trainer
Module Contents
Classes
Functions
Data
API
- class deepxube.base.trainer.TrainArgs
- Parameters:
batch_size – Batch size
max_itrs – Maximum number of iterations
balance_steps – If true, steps are balanced based on solve percentage
rb – Amount of data generated from previous updates to keep in the replay buffer. The total replay buffer size will then be train_args.batch_size * up_args.up_gen_itrs * rb.
loss_thresh – Loss threshold for updating
targ_up_searches – If > 0, do a greedy search with the updater for at least the given number of searches to test whether the target network should be updated. Otherwise, it is updated automatically.
skip_heur – Skip training of the heuristic function
skip_policy – Skip training of the policy
checkpoint – Save a checkpoint file of the network being trained at initialization and at every given number of update checks. The checkpoint number given is the training iteration, not the update number. If 0, checkpointing is not done.
grad_accum – Number of times to split the batch into sub-batches for gradient accumulation
display – Number of iterations between progress displays when training the nnet. No display if 0.
- batch_size: int = None
- max_itrs: int = None
- balance_steps: bool = None
- rb: int = 0
- loss_thresh: float = None
- targ_up_searches: int = 0
- skip_heur: bool = False
- skip_policy: bool = False
- checkpoint: int = 0
- grad_accum: int = 1
- display: int = 100
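The size arithmetic implied by rb and grad_accum can be sketched as follows. The replay buffer formula is taken from the rb description above. The even split into sub-batches is only an assumption about how grad_accum might divide a batch, and both helper names are hypothetical, not part of deepxube.

```python
from typing import List


def replay_buffer_size(batch_size: int, up_gen_itrs: int, rb: int) -> int:
    """Total replay buffer size as stated in the rb description."""
    return batch_size * up_gen_itrs * rb


def sub_batch_sizes(batch_size: int, grad_accum: int) -> List[int]:
    """Assumed even split of one batch into grad_accum sub-batches."""
    base, extra = divmod(batch_size, grad_accum)
    # The first `extra` sub-batches take one additional example each.
    return [base + 1 if i < extra else base for i in range(grad_accum)]


if __name__ == "__main__":
    print(replay_buffer_size(batch_size=1000, up_gen_itrs=100, rb=2))  # 200000
    print(sub_batch_sizes(batch_size=1000, grad_accum=3))  # [334, 333, 333]
```

With the default rb = 0 the formula evaluates to 0. How the trainer treats that case is not stated here, so the sketch does not special-case it.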
- class deepxube.base.trainer.DataBuffer(max_size: int, shapes: List[Tuple[int, ...]], dtypes: List[numpy.dtype])
Initialization
- add(arrays_add: List[numpy.typing.NDArray]) → None
- sample(sel_idxs: numpy.typing.NDArray) → List[numpy.typing.NDArray]
- size() → int
- clear() → None
- _add_circular(arrays_add: List[numpy.typing.NDArray]) → None
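To illustrate what a buffer with this interface might do, here is a minimal sketch of a fixed-size circular buffer over several parallel NumPy arrays. The class name CircularBufferSketch and all of its internals are hypothetical. It mirrors the method names listed above but is not the deepxube implementation, and it assumes each call to add supplies at most max_size rows.

```python
from typing import List, Tuple

import numpy as np


class CircularBufferSketch:
    """Hypothetical stand-in for DataBuffer; not the deepxube implementation."""

    def __init__(self, max_size: int, shapes: List[Tuple[int, ...]], dtypes: List[np.dtype]):
        self.max_size = max_size
        # One preallocated array per field, all sharing the leading (row) axis.
        self.arrays = [np.empty((max_size,) + shape, dtype=dtype)
                       for shape, dtype in zip(shapes, dtypes)]
        self.curr_size = 0  # number of valid rows
        self.add_idx = 0  # next row to write

    def add(self, arrays_add: List[np.ndarray]) -> None:
        num_add = arrays_add[0].shape[0]
        if num_add > self.max_size:
            raise ValueError("cannot add more rows than max_size in one call")
        # Rows to write, wrapping to the start once the end is reached.
        idxs = (self.add_idx + np.arange(num_add)) % self.max_size
        for array, array_add in zip(self.arrays, arrays_add):
            array[idxs] = array_add
        self.add_idx = (self.add_idx + num_add) % self.max_size
        self.curr_size = min(self.curr_size + num_add, self.max_size)

    def sample(self, sel_idxs: np.ndarray) -> List[np.ndarray]:
        # Integer-array indexing returns copies, so callers cannot alter the buffer.
        return [array[sel_idxs] for array in self.arrays]

    def size(self) -> int:
        return self.curr_size

    def clear(self) -> None:
        self.curr_size = 0
        self.add_idx = 0


if __name__ == "__main__":
    buf = CircularBufferSketch(4, [(2,), ()], [np.dtype(np.float32), np.dtype(np.int64)])
    buf.add([np.zeros((3, 2), dtype=np.float32), np.array([0, 1, 2])])
    buf.add([np.ones((2, 2), dtype=np.float32), np.array([3, 4])])  # wraps around
    print(buf.size())  # 4
    print(buf.sample(np.array([0, 3]))[1])  # [4 3]
```

In the usage above, the second add writes one row at the end of the buffer and then overwrites the oldest row at index 0, which is the behavior the name _add_circular suggests.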
- class deepxube.base.trainer.Status(step_max: int, balance_steps: bool)
Initialization
- update_step_probs(step_to_search_perf: Dict[int, deepxube.pathfinding.utils.performance.PathFindPerf]) → None
- class deepxube.base.trainer.TrainSummary
Initialization
- update_pathfindstats(step_to_pathfindperf: Dict[int, deepxube.pathfinding.utils.performance.PathFindPerf], itr: int) → None
- deepxube.base.trainer.NNet = 'TypeVar(...)'
- deepxube.base.trainer.Up = 'TypeVar(...)'
- deepxube.base.trainer.update_optimizer(optimizer: torch.optim.Optimizer, nnet: Union[torch.nn.DataParallel, deepxube.base.heuristic.DeepXubeNNet], train_itr: int) → None
- class deepxube.base.trainer.Train(nnet: deepxube.base.trainer.NNet, updater: deepxube.base.trainer.Up, to_main_q: multiprocessing.Queue, from_main_qs: List[multiprocessing.Queue], nnet_file: str, nnet_targ_file: str, status_file: str, train_summary_file: str, device: torch.device, on_gpu: bool, writer: torch.utils.tensorboard.SummaryWriter, train_args: deepxube.base.trainer.TrainArgs)
Bases: typing.Generic[deepxube.base.trainer.NNet, deepxube.base.trainer.Up], abc.ABC
- abstract static data_parallel() → bool
- update_step() → None
- _get_update_data(num_gen: int, times: deepxube.utils.timing_utils.Times) → None
- _train(times: deepxube.utils.timing_utils.Times) → float
- _train_sync_main(num_gen: int, times: deepxube.utils.timing_utils.Times) → float
- abstract _train_itr(batch: List[numpy.typing.NDArray], first_itr_in_update: bool, times: deepxube.utils.timing_utils.Times) → float
- _end_update(itr_init: int, times: deepxube.utils.timing_utils.Times) → None
- _save_checkpoint() → None
- abstract _add_post_up_info() → List[str]
- abstract _get_shapes_dtypes() → List[Tuple[Tuple[int, ...], numpy.dtype]]