deepxube.base.trainer

Module Contents

Classes

TrainArgs

param batch_size:

Batch size

param max_itrs:

Maximum number of iterations

param balance_steps:

If true, steps are balanced based on solve percentage

param rb:

Amount of data generated from previous updates to keep in replay buffer. Total replay buffer size will then be train_args.batch_size * up_args.up_gen_itrs * rb.

DataBuffer

Status

TrainSummary

Train

Functions

Data

API

class deepxube.base.trainer.TrainArgs
Parameters:
  • batch_size – Batch size

  • max_itrs – Maximum number of iterations

  • balance_steps – If true, steps are balanced based on solve percentage

  • rb – Amount of data generated from previous updates to keep in replay buffer. Total replay buffer size will then be train_args.batch_size * up_args.up_gen_itrs * rb.

  • loss_thresh – Loss threshold for updating.

  • targ_up_searches – If > 0, do a greedy search with updater for a minimum of the given number of searches to test if target network should be updated. Otherwise, it will be updated automatically.

  • skip_heur – Skip training of heuristic function.

  • skip_policy – Skip training of policy.

  • checkpoint – Save checkpoint file of network being trained at initialization and at every given number of update checks. Checkpoint number given is training iteration, not update number. If 0, then checkpointing is not done.

  • grad_accum – Number of times to split batch into sub-batches for gradient accumulation.

  • display – Number of iterations to display progress when training nnet. No display if 0.

batch_size: int = None
max_itrs: int = None
balance_steps: bool = None
rb: int = 0
loss_thresh: float = None
targ_up_searches: int = 0
skip_heur: bool = False
skip_policy: bool = False
checkpoint: int = 0
grad_accum: int = 1
display: int = 100
class deepxube.base.trainer.DataBuffer(max_size: int, shapes: List[Tuple[int, ...]], dtypes: List[numpy.dtype])

Initialization

add(arrays_add: List[numpy.typing.NDArray]) None
sample(sel_idxs: numpy.typing.NDArray) List[numpy.typing.NDArray]
size() int
clear() None
_add_circular(arrays_add: List[numpy.typing.NDArray]) None
class deepxube.base.trainer.Status(step_max: int, balance_steps: bool)

Initialization

update_step_probs(step_to_search_perf: Dict[int, deepxube.pathfinding.utils.performance.PathFindPerf]) None
class deepxube.base.trainer.TrainSummary

Initialization

update_pathfindstats(step_to_pathfindperf: Dict[int, deepxube.pathfinding.utils.performance.PathFindPerf], itr: int) None
deepxube.base.trainer.NNet = 'TypeVar(...)'
deepxube.base.trainer.Up = 'TypeVar(...)'
deepxube.base.trainer.update_optimizer(optimizer: torch.optim.Optimizer, nnet: Union[torch.nn.DataParallel, deepxube.base.heuristic.DeepXubeNNet], train_itr: int) None
class deepxube.base.trainer.Train(nnet: deepxube.base.trainer.NNet, updater: deepxube.base.trainer.Up, to_main_q: multiprocessing.Queue, from_main_qs: List[multiprocessing.Queue], nnet_file: str, nnet_targ_file: str, status_file: str, train_summary_file: str, device: torch.device, on_gpu: bool, writer: torch.utils.tensorboard.SummaryWriter, train_args: deepxube.base.trainer.TrainArgs)

Bases: typing.Generic[deepxube.base.trainer.NNet, deepxube.base.trainer.Up], abc.ABC

abstract static data_parallel() bool
update_step() None
_get_update_data(num_gen: int, times: deepxube.utils.timing_utils.Times) None
_train(times: deepxube.utils.timing_utils.Times) float
_train_sync_main(num_gen: int, times: deepxube.utils.timing_utils.Times) float
abstract _train_itr(batch: List[numpy.typing.NDArray], first_itr_in_update: bool, times: deepxube.utils.timing_utils.Times) float
_end_update(itr_init: int, times: deepxube.utils.timing_utils.Times) None
_save_checkpoint() None
abstract _add_post_up_info() List[str]
abstract _get_shapes_dtypes() List[Tuple[Tuple[int, ...], numpy.dtype]]