deepxube.nnet.nnet_utils
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, Any, Callable, TypeVar, Generic
from dataclasses import dataclass

from deepxube.utils.data_utils import SharedNDArray, np_to_shnd, combine_l_l
import numpy as np
from numpy.typing import NDArray
import os
import torch
from torch import nn
from collections import OrderedDict
import re
from torch import Tensor
from torch.multiprocessing import Queue, get_context
from multiprocessing.process import BaseProcess


# training
def to_pytorch_input(states_nnet: List[NDArray[Any]], device: torch.device) -> List[Tensor]:
    """Convert a list of numpy arrays into torch tensors placed on ``device``.

    :param states_nnet: numpy arrays representing network inputs
    :param device: device on which the resulting tensors are allocated
    :return: one tensor per input array, in the same order
    """
    states_nnet_tensors = []
    for tensor_np in states_nnet:
        tensor = torch.tensor(tensor_np, device=device)
        states_nnet_tensors.append(tensor)

    return states_nnet_tensors


# pytorch device
def get_device() -> Tuple[torch.device, List[int], bool]:
    """Pick the compute device based on CUDA availability and CUDA_VISIBLE_DEVICES.

    Also sets the torch intra-op thread count: 1 when running on GPU (CPU is
    only feeding the GPU), 8 when running purely on CPU.

    :return: (device, list of visible GPU numbers, whether running on GPU)
    """
    device: torch.device = torch.device("cpu")
    devices: List[int] = []
    on_gpu: bool = False
    if ('CUDA_VISIBLE_DEVICES' in os.environ) and torch.cuda.is_available():
        device = torch.device("cuda:%i" % 0)
        devices = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")]
        on_gpu = True
        torch.set_num_threads(1)
    else:
        torch.set_num_threads(8)

    return device, devices, on_gpu


# loading nnet
def load_nnet(model_file: str, nnet: nn.Module, device: Optional[torch.device] = None) -> nn.Module:
    """Load a saved state dict into ``nnet``, stripping any DataParallel prefix.

    The network is put into eval mode before being returned.

    :param model_file: path to a saved state dict
    :param nnet: module instance to load the parameters into
    :param device: optional device to map the loaded tensors onto
    :return: ``nnet`` with the loaded parameters, in eval mode
    """
    # get state dict. weights_only=True in both branches: a state dict is
    # pure tensors, and weights_only=False would unpickle arbitrary objects
    # from the checkpoint (unsafe for untrusted files).
    if device is None:
        state_dict = torch.load(model_file, weights_only=True)
    else:
        state_dict = torch.load(model_file, map_location=device, weights_only=True)

    # remove "module." prefix left by nn.DataParallel wrapping
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        k = re.sub(r'^module\.', '', k)
        new_state_dict[k] = v

    # set state dict
    nnet.load_state_dict(new_state_dict)

    nnet.eval()

    return nnet


def get_available_gpu_nums() -> List[int]:
    """Parse CUDA_VISIBLE_DEVICES into a list of GPU numbers (empty if unset/blank)."""
    gpu_nums: List[int] = []
    if ('CUDA_VISIBLE_DEVICES' in os.environ) and (len(os.environ['CUDA_VISIBLE_DEVICES']) > 0):
        gpu_nums = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")]

    return gpu_nums


def nnet_batched(nnet: nn.Module, inputs: List[NDArray[Any]], batch_size: Optional[int],
                 device: torch.device) -> List[NDArray[np.float64]]:
    """Run ``nnet`` over ``inputs`` in minibatches and concatenate the outputs.

    ``nnet`` is expected to take a list of tensors and return a list of
    tensors; the number of outputs must be the same for every batch.

    :param nnet: network to evaluate
    :param inputs: input arrays; all must share the same first (batch) dimension
    :param batch_size: minibatch size; None evaluates everything in one batch
    :param device: device on which to run the network
    :return: network outputs as float64 arrays, concatenated over all batches
    """
    outputs_l_l: List[List[NDArray[np.float64]]] = []

    num_states: int = inputs[0].shape[0]

    batch_size_inst: int = num_states
    if batch_size is not None:
        batch_size_inst = batch_size

    start_idx: int = 0
    num_outputs: Optional[int] = None
    while start_idx < num_states:
        # get batch
        end_idx: int = min(start_idx + batch_size_inst, num_states)
        inputs_batch = [x[start_idx:end_idx] for x in inputs]

        # get nnet output
        inputs_batch_t = to_pytorch_input(inputs_batch, device)

        outputs_batch_t_l: List[Tensor] = nnet(inputs_batch_t)
        outputs_batch_l: List[NDArray[np.float64]] = [outputs_batch_t.cpu().data.numpy()
                                                      for outputs_batch_t in outputs_batch_t_l]
        # sanity check: the network must emit the same number of outputs each batch
        if num_outputs is None:
            num_outputs = len(outputs_batch_l)
        else:
            assert len(outputs_batch_l) == num_outputs, f"{len(outputs_batch_l)} != {num_outputs}"

        for out_idx in range(len(outputs_batch_l)):
            outputs_batch_l[out_idx] = outputs_batch_l[out_idx].astype(np.float64)
        outputs_l_l.append(outputs_batch_l)

        start_idx = end_idx

    outputs_l: List[NDArray[np.float64]] = combine_l_l(outputs_l_l, "concat")
    for out_idx in range(len(outputs_l)):
        assert (outputs_l[out_idx].shape[0] == num_states)

    return outputs_l


@dataclass
class NNetParInfo:
    """Queues and id a client process uses to talk to the nnet runner processes."""
    nnet_i_q: Queue  # shared input queue: (proc_id, inputs) tuples go here
    nnet_o_q: Queue  # per-client output queue; runner puts results for proc_id here
    proc_id: int  # index identifying this client's output queue


def nnet_in_out_shared_q(nnet: nn.Module, inputs_nnet_shm: List[SharedNDArray], batch_size: Optional[int],
                         device: torch.device, nnet_o_q: Queue) -> None:
    """Evaluate ``nnet`` on shared-memory inputs and put shared-memory outputs on ``nnet_o_q``.

    Closes (but does not unlink) both the input and output shared arrays in
    this process after sending; the receiving process copies and unlinks them.
    """
    # get outputs
    inputs_nnet: List[NDArray] = []
    for inputs_idx in range(len(inputs_nnet_shm)):
        inputs_nnet.append(inputs_nnet_shm[inputs_idx].array)

    outputs_l: List[NDArray[np.float64]] = nnet_batched(nnet, inputs_nnet, batch_size, device)

    # send outputs
    outputs_l_shm: List[SharedNDArray] = [np_to_shnd(outputs) for outputs in outputs_l]
    nnet_o_q.put(outputs_l_shm)

    for arr_shm in inputs_nnet_shm + outputs_l_shm:
        arr_shm.close()


# parallel neural networks
def nnet_fn_runner(nnet_i_q: Queue, nnet_o_qs: List[Queue], model_file: str, device: torch.device, on_gpu: bool,
                   gpu_num: Optional[int], get_nnet: Callable[[], nn.Module], batch_size: Optional[int]) -> None:
    """Worker loop: load the network, then serve evaluation requests from ``nnet_i_q``.

    Each request is a (proc_id, shared input arrays) tuple; results are sent
    on ``nnet_o_qs[proc_id]``. A request with proc_id of None shuts the loop down.

    :param gpu_num: GPU to pin via CUDA_VISIBLE_DEVICES when ``on_gpu`` (ignored if None)
    """
    if (gpu_num is not None) and on_gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)

    torch.set_num_threads(1)
    nnet: nn.Module = get_nnet()
    nnet = load_nnet(model_file, nnet, device=device)
    nnet.eval()
    nnet.to(device)
    # if on_gpu:
    #     nnet = nn.DataParallel(nnet)

    while True:
        # get from input q
        inputs_nnet_shm: Optional[List[SharedNDArray]]
        proc_id, inputs_nnet_shm = nnet_i_q.get()
        if proc_id is None:
            break

        # nnet in/out
        nnet_in_out_shared_q(nnet, inputs_nnet_shm, batch_size, device, nnet_o_qs[proc_id])


def get_nnet_par_infos(num_procs: int) -> List[NNetParInfo]:
    """Create one shared input queue and ``num_procs`` per-client output queues.

    Output queues have capacity 1: each client has at most one request in flight.
    """
    ctx = get_context("spawn")

    nnet_i_q: Queue = ctx.Queue()
    nnet_o_qs: List[Queue] = []
    nnet_par_infos: List[NNetParInfo] = []
    for proc_id in range(num_procs):
        nnet_o_q: Queue = ctx.Queue(1)
        nnet_o_qs.append(nnet_o_q)
        nnet_par_infos.append(NNetParInfo(nnet_i_q, nnet_o_q, proc_id))

    return nnet_par_infos


def start_nnet_fn_runners(get_nnet: Callable[[], nn.Module], nnet_par_infos: List[NNetParInfo], model_file: str,
                          device: torch.device, on_gpu: bool,
                          batch_size: Optional[int] = None) -> List[BaseProcess]:
    """Spawn one daemon ``nnet_fn_runner`` process per visible GPU (or one CPU process).

    All runners share the single input queue from ``nnet_par_infos``; each
    routes results to the requesting client's output queue.

    :return: the started processes (join them via :func:`stop_nnet_runners`)
    """
    ctx = get_context("spawn")

    nnet_i_q: Queue = nnet_par_infos[0].nnet_i_q
    nnet_o_qs: List[Queue] = [nnet_par_info.nnet_o_q for nnet_par_info in nnet_par_infos]

    # one runner per visible GPU; -1 means a single CPU-only runner
    gpu_nums: List[int] = get_available_gpu_nums()
    if len(gpu_nums) == 0:
        gpu_nums = [-1]

    nnet_procs: List[BaseProcess] = []
    for gpu_num in gpu_nums:
        nnet_fn_proc = ctx.Process(target=nnet_fn_runner,
                                   args=(nnet_i_q, nnet_o_qs, model_file, device, on_gpu, gpu_num, get_nnet,
                                         batch_size))
        nnet_fn_proc.daemon = True
        nnet_fn_proc.start()
        nnet_procs.append(nnet_fn_proc)

    return nnet_procs


def stop_nnet_runners(nnet_fn_procs: List[BaseProcess], nnet_par_infos: List[NNetParInfo]) -> None:
    """Send one shutdown sentinel per runner on the shared input queue, then join them."""
    for _ in nnet_fn_procs:
        nnet_par_infos[0].nnet_i_q.put((None, None))

    for heur_fn_proc in nnet_fn_procs:
        heur_fn_proc.join()


NNetCallable = Callable[..., Any]
NNetFn = TypeVar('NNetFn', bound=NNetCallable)


class NNetPar(ABC, Generic[NNetFn]):
    """Interface for things that expose a network both in-process and via runner processes."""

    @abstractmethod
    def get_nnet_fn(self, nnet: nn.Module, batch_size: Optional[int], device: torch.device,
                    update_num: Optional[int]) -> NNetFn:
        """Return a callable that evaluates ``nnet`` directly in this process."""
        pass

    @abstractmethod
    def get_nnet_par_fn(self, nnet_par_info: NNetParInfo, update_num: Optional[int]) -> NNetFn:
        """Return a callable that evaluates the network via the parallel runner queues."""
        pass

    @abstractmethod
    def get_nnet(self) -> nn.Module:
        """Construct a fresh, uninitialized instance of the network."""
        pass


def get_nnet_par_out(inputs_nnet: List[NDArray], nnet_par_info: NNetParInfo) -> List[NDArray]:
    """Send inputs to a runner process via shared memory and return copied outputs.

    Blocks until the runner replies on this client's output queue. Both input
    and output shared arrays are closed and unlinked here, after the outputs
    have been copied into regular numpy arrays.
    """
    inputs_nnet_shm: List[SharedNDArray] = [np_to_shnd(inputs_nnet_i) for inputs_nnet_i in inputs_nnet]

    nnet_par_info.nnet_i_q.put((nnet_par_info.proc_id, inputs_nnet_shm))

    out_shm_l: List[SharedNDArray] = nnet_par_info.nnet_o_q.get()
    out_l: List[NDArray] = [out_shm.array.copy() for out_shm in out_shm_l]

    for arr_shm in inputs_nnet_shm + out_shm_l:
        arr_shm.close()
        arr_shm.unlink()

    return out_l
def
to_pytorch_input( states_nnet: List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]], device: torch.device) -> List[torch.Tensor]:
def
get_device() -> Tuple[torch.device, List[int], bool]:
30def get_device() -> Tuple[torch.device, List[int], bool]: 31 device: torch.device = torch.device("cpu") 32 devices: List[int] = [] 33 on_gpu: bool = False 34 if ('CUDA_VISIBLE_DEVICES' in os.environ) and torch.cuda.is_available(): 35 device = torch.device("cuda:%i" % 0) 36 devices = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")] 37 on_gpu = True 38 torch.set_num_threads(1) 39 else: 40 torch.set_num_threads(8) 41 42 return device, devices, on_gpu
def
load_nnet( model_file: str, nnet: torch.nn.modules.module.Module, device: Optional[torch.device] = None) -> torch.nn.modules.module.Module:
46def load_nnet(model_file: str, nnet: nn.Module, device: Optional[torch.device] = None) -> nn.Module: 47 # get state dict 48 if device is None: 49 state_dict = torch.load(model_file, weights_only=True) 50 else: 51 state_dict = torch.load(model_file, map_location=device, weights_only=False) 52 53 # remove module prefix 54 new_state_dict = OrderedDict() 55 for k, v in state_dict.items(): 56 k = re.sub(r'^module\.', '', k) 57 new_state_dict[k] = v 58 59 # set state dict 60 nnet.load_state_dict(new_state_dict) 61 62 nnet.eval() 63 64 return nnet
def
get_available_gpu_nums() -> List[int]:
def
nnet_batched( nnet: torch.nn.modules.module.Module, inputs: List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]], batch_size: Optional[int], device: torch.device) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.float64]]]:
75def nnet_batched(nnet: nn.Module, inputs: List[NDArray[Any]], batch_size: Optional[int], 76 device: torch.device) -> List[NDArray[np.float64]]: 77 outputs_l_l: List[List[NDArray[np.float64]]] = [] 78 79 num_states: int = inputs[0].shape[0] 80 81 batch_size_inst: int = num_states 82 if batch_size is not None: 83 batch_size_inst = batch_size 84 85 start_idx: int = 0 86 num_outputs: Optional[int] = None 87 while start_idx < num_states: 88 # get batch 89 end_idx: int = min(start_idx + batch_size_inst, num_states) 90 inputs_batch = [x[start_idx:end_idx] for x in inputs] 91 92 # get nnet output 93 inputs_batch_t = to_pytorch_input(inputs_batch, device) 94 95 outputs_batch_t_l: List[Tensor] = nnet(inputs_batch_t) 96 outputs_batch_l: List[NDArray[np.float64]] = [outputs_batch_t.cpu().data.numpy() 97 for outputs_batch_t in outputs_batch_t_l] 98 if num_outputs is None: 99 num_outputs = len(outputs_batch_l) 100 else: 101 assert len(outputs_batch_l) == num_outputs, f"{len(outputs_batch_l)} != {num_outputs}" 102 103 for out_idx in range(len(outputs_batch_l)): 104 outputs_batch_l[out_idx] = outputs_batch_l[out_idx].astype(np.float64) 105 outputs_l_l.append(outputs_batch_l) 106 107 start_idx = end_idx 108 109 outputs_l: List[NDArray[np.float64]] = combine_l_l(outputs_l_l, "concat") 110 for out_idx in range(len(outputs_l)): 111 assert (outputs_l[out_idx].shape[0] == num_states) 112 113 return outputs_l
@dataclass
class
NNetParInfo:
NNetParInfo( nnet_i_q: Queue, nnet_o_q: Queue, proc_id: int)
nnet_i_q: Queue
def
nnet_fn_runner( nnet_i_q: Queue, nnet_o_qs: List[Queue], model_file: str, device: torch.device, on_gpu: bool, gpu_num: int, get_nnet: Callable[[], torch.nn.modules.module.Module], batch_size: Optional[int]) -> None:
141def nnet_fn_runner(nnet_i_q: Queue, nnet_o_qs: List[Queue], model_file: str, device: torch.device, on_gpu: bool, 142 gpu_num: int, get_nnet: Callable[[], nn.Module], batch_size: Optional[int]) -> None: 143 if (gpu_num is not None) and on_gpu: 144 os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) 145 146 torch.set_num_threads(1) 147 nnet: nn.Module = get_nnet() 148 nnet = load_nnet(model_file, nnet, device=device) 149 nnet.eval() 150 nnet.to(device) 151 # if on_gpu: 152 # nnet = nn.DataParallel(nnet) 153 154 while True: 155 # get from input q 156 inputs_nnet_shm: Optional[List[SharedNDArray]] 157 proc_id, inputs_nnet_shm = nnet_i_q.get() 158 if proc_id is None: 159 break 160 161 # nnet in/out 162 nnet_in_out_shared_q(nnet, inputs_nnet_shm, batch_size, device, nnet_o_qs[proc_id])
165def get_nnet_par_infos(num_procs: int) -> List[NNetParInfo]: 166 ctx = get_context("spawn") 167 168 nnet_i_q: Queue = ctx.Queue() 169 nnet_o_qs: List[Queue] = [] 170 nnet_par_infos: List[NNetParInfo] = [] 171 for proc_id in range(num_procs): 172 nnet_o_q: Queue = ctx.Queue(1) 173 nnet_o_qs.append(nnet_o_q) 174 nnet_par_infos.append(NNetParInfo(nnet_i_q, nnet_o_q, proc_id)) 175 176 return nnet_par_infos
def
start_nnet_fn_runners( get_nnet: Callable[[], torch.nn.modules.module.Module], nnet_par_infos: List[NNetParInfo], model_file: str, device: torch.device, on_gpu: bool, batch_size: Optional[int] = None) -> List[multiprocessing.process.BaseProcess]:
179def start_nnet_fn_runners(get_nnet: Callable[[], nn.Module], nnet_par_infos: List[NNetParInfo], model_file: str, 180 device: torch.device, on_gpu: bool, 181 batch_size: Optional[int] = None) -> List[BaseProcess]: 182 ctx = get_context("spawn") 183 184 nnet_i_q: Queue = nnet_par_infos[0].nnet_i_q 185 nnet_o_qs: List[Queue] = [nnet_par_info.nnet_o_q for nnet_par_info in nnet_par_infos] 186 187 # initialize heuristic procs 188 if ('CUDA_VISIBLE_DEVICES' in os.environ) and (len(os.environ['CUDA_VISIBLE_DEVICES']) > 0): 189 gpu_nums = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")] 190 else: 191 gpu_nums = [-1] 192 193 nnet_procs: List[BaseProcess] = [] 194 for gpu_num in gpu_nums: 195 nnet_fn_procs = ctx.Process(target=nnet_fn_runner, 196 args=(nnet_i_q, nnet_o_qs, model_file, device, on_gpu, gpu_num, get_nnet, 197 batch_size)) 198 nnet_fn_procs.daemon = True 199 nnet_fn_procs.start() 200 nnet_procs.append(nnet_fn_procs) 201 202 return nnet_procs
def
stop_nnet_runners( nnet_fn_procs: List[multiprocessing.process.BaseProcess], nnet_par_infos: List[NNetParInfo]) -> None:
NNetCallable =
typing.Callable[..., typing.Any]
class
NNetPar(abc.ABC, typing.Generic[~NNetFn]):
217class NNetPar(ABC, Generic[NNetFn]): 218 @abstractmethod 219 def get_nnet_fn(self, nnet: nn.Module, batch_size: Optional[int], device: torch.device, 220 update_num: Optional[int]) -> NNetFn: 221 pass 222 223 @abstractmethod 224 def get_nnet_par_fn(self, nnet_par_info: NNetParInfo, update_num: Optional[int]) -> NNetFn: 225 pass 226 227 @abstractmethod 228 def get_nnet(self) -> nn.Module: 229 pass
Helper class that provides a standard way to create an ABC using inheritance.
@abstractmethod
def
get_nnet_fn( self, nnet: torch.nn.modules.module.Module, batch_size: Optional[int], device: torch.device, update_num: Optional[int]) -> ~NNetFn:
@abstractmethod
def
get_nnet_par_fn( self, nnet_par_info: NNetParInfo, update_num: Optional[int]) -> ~NNetFn:
def
get_nnet_par_out( inputs_nnet: List[numpy.ndarray[tuple[int, ...], numpy.dtype[+_ScalarType_co]]], nnet_par_info: NNetParInfo) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[+_ScalarType_co]]]:
232def get_nnet_par_out(inputs_nnet: List[NDArray], nnet_par_info: NNetParInfo) -> List[NDArray]: 233 inputs_nnet_shm: List[SharedNDArray] = [np_to_shnd(inputs_nnet_i) 234 for input_idx, inputs_nnet_i in enumerate(inputs_nnet)] 235 236 nnet_par_info.nnet_i_q.put((nnet_par_info.proc_id, inputs_nnet_shm)) 237 238 out_shm_l: List[SharedNDArray] = nnet_par_info.nnet_o_q.get() 239 out_l: List[NDArray] = [out_shm.array.copy() for out_shm in out_shm_l] 240 241 for arr_shm in inputs_nnet_shm + out_shm_l: 242 arr_shm.close() 243 arr_shm.unlink() 244 245 return out_l