deepxube.trainers.train_policy
Module Contents
Classes
API
- class deepxube.trainers.train_policy.TrainPolicy(nnet: deepxube.base.trainer.NNet, updater: deepxube.base.trainer.Up, to_main_q: multiprocessing.Queue, from_main_qs: List[multiprocessing.Queue], nnet_file: str, nnet_targ_file: str, status_file: str, train_summary_file: str, device: torch.device, on_gpu: bool, writer: torch.utils.tensorboard.SummaryWriter, train_args: deepxube.base.trainer.TrainArgs)
Bases: deepxube.base.trainer.Train[deepxube.base.heuristic.PolicyNNet, deepxube.base.updater.UpdatePolicy]
- static data_parallel() → bool
- _train_itr(batch: List[numpy.typing.NDArray], first_itr_in_update: bool, times: deepxube.utils.timing_utils.Times) → float
- _add_post_up_info() → List[str]
- _get_shapes_dtypes() → List[Tuple[Tuple[int, ...], numpy.dtype]]
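The return type of `_get_shapes_dtypes` is a list of `(shape, dtype)` pairs, and `_train_itr` consumes a batch given as a list of NumPy arrays. The sketch below shows one way a spec of that form could be used to preallocate and validate such a batch. It assumes that each pair describes one array of the batch, excluding a leading batch dimension. That pairing, the helpers `alloc_batch` and `matches_spec`, and the concrete shapes and dtypes are all hypothetical illustrations, not deepxube's actual behavior.

```python
from typing import List, Tuple

import numpy as np
from numpy.typing import NDArray

# Same form as the _get_shapes_dtypes() return annotation.
# ASSUMPTION: one (per-example shape, dtype) entry per array in a batch.
ShapesDtypes = List[Tuple[Tuple[int, ...], np.dtype]]


def alloc_batch(shapes_dtypes: ShapesDtypes, batch_size: int) -> List[NDArray]:
    """Hypothetical helper: allocate one zero-filled array per spec entry,
    prepending a batch dimension of size batch_size."""
    return [np.zeros((batch_size, *shape), dtype=dtype) for shape, dtype in shapes_dtypes]


def matches_spec(batch: List[NDArray], shapes_dtypes: ShapesDtypes) -> bool:
    """Hypothetical helper: check that each array's non-batch dimensions
    and dtype agree with its spec entry."""
    if len(batch) != len(shapes_dtypes):
        return False
    return all(
        arr.shape[1:] == shape and arr.dtype == dtype
        for arr, (shape, dtype) in zip(batch, shapes_dtypes)
    )


# Made-up spec: a 3x3 uint8 state encoding plus a scalar float32 target.
spec: ShapesDtypes = [((3, 3), np.dtype(np.uint8)), ((), np.dtype(np.float32))]
batch = alloc_batch(spec, batch_size=4)
print([a.shape for a in batch])   # [(4, 3, 3), (4,)]
print(matches_spec(batch, spec))  # True
```

Keeping shapes and dtypes as plain data like this lets buffers be sized once, up front, rather than inferred from the first batch. Whether `TrainPolicy` uses its spec this way is not stated in this reference.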