deepxube.trainers.train_policy

Module Contents

Classes

API

class deepxube.trainers.train_policy.TrainPolicy(nnet: deepxube.base.trainer.NNet, updater: deepxube.base.trainer.Up, to_main_q: multiprocessing.Queue, from_main_qs: List[multiprocessing.Queue], nnet_file: str, nnet_targ_file: str, status_file: str, train_summary_file: str, device: torch.device, on_gpu: bool, writer: torch.utils.tensorboard.SummaryWriter, train_args: deepxube.base.trainer.TrainArgs)

Bases: deepxube.base.trainer.Train[deepxube.base.heuristic.PolicyNNet, deepxube.base.updater.UpdatePolicy]

static data_parallel() → bool
_train_itr(batch: List[numpy.typing.NDArray], first_itr_in_update: bool, times: deepxube.utils.timing_utils.Times) → float
_add_post_up_info() → List[str]
_get_shapes_dtypes() → List[Tuple[Tuple[int, ...], numpy.dtype]]