deepxube.nnet.nnet_utils

  1from abc import ABC, abstractmethod
  2from typing import List, Tuple, Optional, Any, Callable, TypeVar, Generic
  3from dataclasses import dataclass
  4
  5from deepxube.utils.data_utils import SharedNDArray, np_to_shnd, combine_l_l
  6import numpy as np
  7from numpy.typing import NDArray
  8import os
  9import torch
 10from torch import nn
 11from collections import OrderedDict
 12import re
 13from torch import Tensor
 14from torch.multiprocessing import Queue, get_context
 15from multiprocessing.process import BaseProcess
 16
 17
 18# training
 19def to_pytorch_input(states_nnet: List[NDArray[Any]], device: torch.device) -> List[Tensor]:
 20    states_nnet_tensors = []
 21    for tensor_np in states_nnet:
 22        tensor = torch.tensor(tensor_np, device=device)
 23        states_nnet_tensors.append(tensor)
 24
 25    return states_nnet_tensors
 26
 27
 28# pytorch device
 29def get_device() -> Tuple[torch.device, List[int], bool]:
 30    device: torch.device = torch.device("cpu")
 31    devices: List[int] = []
 32    on_gpu: bool = False
 33    if ('CUDA_VISIBLE_DEVICES' in os.environ) and torch.cuda.is_available():
 34        device = torch.device("cuda:%i" % 0)
 35        devices = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")]
 36        on_gpu = True
 37        torch.set_num_threads(1)
 38    else:
 39        torch.set_num_threads(8)
 40
 41    return device, devices, on_gpu
 42
 43
 44# loading nnet
 45def load_nnet(model_file: str, nnet: nn.Module, device: Optional[torch.device] = None) -> nn.Module:
 46    # get state dict
 47    if device is None:
 48        state_dict = torch.load(model_file, weights_only=True)
 49    else:
 50        state_dict = torch.load(model_file, map_location=device, weights_only=False)
 51
 52    # remove module prefix
 53    new_state_dict = OrderedDict()
 54    for k, v in state_dict.items():
 55        k = re.sub(r'^module\.', '', k)
 56        new_state_dict[k] = v
 57
 58    # set state dict
 59    nnet.load_state_dict(new_state_dict)
 60
 61    nnet.eval()
 62
 63    return nnet
 64
 65
 66def get_available_gpu_nums() -> List[int]:
 67    gpu_nums: List[int] = []
 68    if ('CUDA_VISIBLE_DEVICES' in os.environ) and (len(os.environ['CUDA_VISIBLE_DEVICES']) > 0):
 69        gpu_nums = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")]
 70
 71    return gpu_nums
 72
 73
 74def nnet_batched(nnet: nn.Module, inputs: List[NDArray[Any]], batch_size: Optional[int],
 75                 device: torch.device) -> List[NDArray[np.float64]]:
 76    outputs_l_l: List[List[NDArray[np.float64]]] = []
 77
 78    num_states: int = inputs[0].shape[0]
 79
 80    batch_size_inst: int = num_states
 81    if batch_size is not None:
 82        batch_size_inst = batch_size
 83
 84    start_idx: int = 0
 85    num_outputs: Optional[int] = None
 86    while start_idx < num_states:
 87        # get batch
 88        end_idx: int = min(start_idx + batch_size_inst, num_states)
 89        inputs_batch = [x[start_idx:end_idx] for x in inputs]
 90
 91        # get nnet output
 92        inputs_batch_t = to_pytorch_input(inputs_batch, device)
 93
 94        outputs_batch_t_l: List[Tensor] = nnet(inputs_batch_t)
 95        outputs_batch_l: List[NDArray[np.float64]] = [outputs_batch_t.cpu().data.numpy()
 96                                                      for outputs_batch_t in outputs_batch_t_l]
 97        if num_outputs is None:
 98            num_outputs = len(outputs_batch_l)
 99        else:
100            assert len(outputs_batch_l) == num_outputs, f"{len(outputs_batch_l)} != {num_outputs}"
101
102        for out_idx in range(len(outputs_batch_l)):
103            outputs_batch_l[out_idx] = outputs_batch_l[out_idx].astype(np.float64)
104        outputs_l_l.append(outputs_batch_l)
105
106        start_idx = end_idx
107
108    outputs_l: List[NDArray[np.float64]] = combine_l_l(outputs_l_l, "concat")
109    for out_idx in range(len(outputs_l)):
110        assert (outputs_l[out_idx].shape[0] == num_states)
111
112    return outputs_l
113
114
@dataclass
class NNetParInfo:
    """Queue handles a requesting process uses to talk to the nnet runner processes."""
    # request queue shared by all runners; carries (proc_id, inputs) tuples
    nnet_i_q: Queue
    # this requester's private result queue, written by whichever runner serves it
    nnet_o_q: Queue
    # index of this requester; runners use it to select nnet_o_qs[proc_id]
    proc_id: int
120
121
def nnet_in_out_shared_q(nnet: nn.Module, inputs_nnet_shm: List[SharedNDArray], batch_size: Optional[int],
                         device: torch.device, nnet_o_q: Queue) -> None:
    """Run nnet on shared-memory inputs and post shared-memory outputs to nnet_o_q.

    :param nnet: network to evaluate
    :param inputs_nnet_shm: shared-memory wrappers around the numpy input arrays
    :param batch_size: minibatch size passed through to nnet_batched
    :param device: device inference runs on
    :param nnet_o_q: queue the requesting process reads results from
    """
    # view shared memory as plain numpy arrays
    inputs_np: List[NDArray] = [shm.array for shm in inputs_nnet_shm]

    outputs_np: List[NDArray[np.float64]] = nnet_batched(nnet, inputs_np, batch_size, device)

    # ship results back through shared memory
    outputs_shm: List[SharedNDArray] = [np_to_shnd(out_np) for out_np in outputs_np]
    nnet_o_q.put(outputs_shm)

    # close this process's handles; the requester copies the data and unlinks
    for shm in inputs_nnet_shm + outputs_shm:
        shm.close()
137
138
139# parallel neural networks
def nnet_fn_runner(nnet_i_q: Queue, nnet_o_qs: List[Queue], model_file: str, device: torch.device, on_gpu: bool,
                   gpu_num: Optional[int], get_nnet: Callable[[], nn.Module], batch_size: Optional[int]) -> None:
    """Worker-process loop: load the network once, then serve inference requests.

    Blocks on nnet_i_q for (proc_id, shared-memory inputs) tuples and answers each
    on nnet_o_qs[proc_id] via nnet_in_out_shared_q. A tuple whose proc_id is None
    (sent by stop_nnet_runners) terminates the loop.

    :param nnet_i_q: shared request queue
    :param nnet_o_qs: per-requester result queues, indexed by proc_id
    :param model_file: path to saved weights, loaded with load_nnet
    :param device: device inference runs on
    :param on_gpu: if True (and gpu_num is not None), pin this process to one GPU
    :param gpu_num: GPU id written to CUDA_VISIBLE_DEVICES; -1/None for CPU-only
    :param get_nnet: factory that builds the (uninitialized) network architecture
    :param batch_size: minibatch size for nnet_batched; None evaluates in one batch
    """
    if (gpu_num is not None) and on_gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)

    # one intra-op thread per worker; parallelism comes from multiple processes
    torch.set_num_threads(1)
    nnet: nn.Module = get_nnet()
    nnet = load_nnet(model_file, nnet, device=device)
    nnet.eval()
    nnet.to(device)
    # if on_gpu:
    #    nnet = nn.DataParallel(nnet)

    while True:
        # get from input q
        inputs_nnet_shm: Optional[List[SharedNDArray]]
        proc_id, inputs_nnet_shm = nnet_i_q.get()
        if proc_id is None:
            # shutdown sentinel from stop_nnet_runners
            break

        # nnet in/out
        nnet_in_out_shared_q(nnet, inputs_nnet_shm, batch_size, device, nnet_o_qs[proc_id])
162
163
def get_nnet_par_infos(num_procs: int) -> List[NNetParInfo]:
    """Create queue bundles for num_procs requesting processes.

    All requesters share one input queue; each gets its own output queue.
    (Fix: the original also built a nnet_o_qs list that was never used — removed.)

    :param num_procs: number of requesting processes
    :return: one NNetParInfo per requester, proc_id = 0..num_procs-1
    """
    ctx = get_context("spawn")

    nnet_i_q: Queue = ctx.Queue()
    nnet_par_infos: List[NNetParInfo] = []
    for proc_id in range(num_procs):
        # maxsize 1: each request produces exactly one result before the requester reads it
        nnet_par_infos.append(NNetParInfo(nnet_i_q, ctx.Queue(1), proc_id))

    return nnet_par_infos
176
177
def start_nnet_fn_runners(get_nnet: Callable[[], nn.Module], nnet_par_infos: List[NNetParInfo], model_file: str,
                          device: torch.device, on_gpu: bool,
                          batch_size: Optional[int] = None) -> List[BaseProcess]:
    """Spawn one daemon nnet_fn_runner process per visible GPU (or one CPU process).

    :param get_nnet: factory that builds the network architecture in the worker
    :param nnet_par_infos: queue bundles from get_nnet_par_infos (all share one input queue)
    :param model_file: path to saved weights each worker loads
    :param device: device workers run inference on
    :param on_gpu: whether workers pin themselves to a GPU
    :param batch_size: minibatch size passed through to the workers; None for one batch
    :return: the started worker processes
    """
    ctx = get_context("spawn")

    in_q: Queue = nnet_par_infos[0].nnet_i_q
    out_qs: List[Queue] = [info.nnet_o_q for info in nnet_par_infos]

    # one worker per GPU id in CUDA_VISIBLE_DEVICES; -1 means a single CPU-only worker
    cuda_env: str = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    if cuda_env:
        gpu_nums = [int(num) for num in cuda_env.split(",")]
    else:
        gpu_nums = [-1]

    nnet_procs: List[BaseProcess] = []
    for gpu_num in gpu_nums:
        proc = ctx.Process(target=nnet_fn_runner,
                           args=(in_q, out_qs, model_file, device, on_gpu, gpu_num, get_nnet, batch_size))
        proc.daemon = True
        proc.start()
        nnet_procs.append(proc)

    return nnet_procs
202
203
def stop_nnet_runners(nnet_fn_procs: List[BaseProcess], nnet_par_infos: List[NNetParInfo]) -> None:
    """Shut down runner processes: one (None, None) sentinel per worker, then join all."""
    shutdown_q = nnet_par_infos[0].nnet_i_q
    for _ in nnet_fn_procs:
        shutdown_q.put((None, None))

    for proc in nnet_fn_procs:
        proc.join()
210
211
# Opaque callable type produced by NNetPar factories; each subclass pins a
# concrete callable signature via the bounded TypeVar below.
NNetCallable = Callable[..., Any]
NNetFn = TypeVar('NNetFn', bound=NNetCallable)
214
215
class NNetPar(ABC, Generic[NNetFn]):
    """Interface for producing network-evaluation callables, either in-process or
    through the parallel runner machinery (NNetParInfo queues).

    NNetFn is the subclass-specific callable signature both factories return.
    """

    @abstractmethod
    def get_nnet_fn(self, nnet: nn.Module, batch_size: Optional[int], device: torch.device,
                    update_num: Optional[int]) -> NNetFn:
        """Return an NNetFn wrapping direct evaluation of nnet on device.

        update_num semantics are implementation-defined — presumably a model
        version/update counter; confirm against implementers.
        """
        pass

    @abstractmethod
    def get_nnet_par_fn(self, nnet_par_info: NNetParInfo, update_num: Optional[int]) -> NNetFn:
        """Return an NNetFn that routes evaluation through the runner process
        described by nnet_par_info."""
        pass

    @abstractmethod
    def get_nnet(self) -> nn.Module:
        """Construct and return the (uninitialized) network architecture."""
        pass
229
230
def get_nnet_par_out(inputs_nnet: List[NDArray], nnet_par_info: NNetParInfo) -> List[NDArray]:
    """Send inputs to a runner process via shared memory and return its outputs.

    Blocks until the runner posts results on this requester's output queue.
    (Fix: dropped a pointless enumerate whose index was never used.)

    :param inputs_nnet: numpy input arrays for the network
    :param nnet_par_info: this requester's queue bundle
    :return: copies of the network outputs as plain numpy arrays
    """
    # move inputs into shared memory for the runner process
    inputs_nnet_shm: List[SharedNDArray] = [np_to_shnd(inputs_nnet_i) for inputs_nnet_i in inputs_nnet]

    nnet_par_info.nnet_i_q.put((nnet_par_info.proc_id, inputs_nnet_shm))

    # block for results, then copy out of shared memory so segments can be freed
    out_shm_l: List[SharedNDArray] = nnet_par_info.nnet_o_q.get()
    out_l: List[NDArray] = [out_shm.array.copy() for out_shm in out_shm_l]

    # this side owns cleanup of both input and output segments
    for arr_shm in inputs_nnet_shm + out_shm_l:
        arr_shm.close()
        arr_shm.unlink()

    return out_l
def to_pytorch_input( states_nnet: List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]], device: torch.device) -> List[torch.Tensor]:
20def to_pytorch_input(states_nnet: List[NDArray[Any]], device: torch.device) -> List[Tensor]:
21    states_nnet_tensors = []
22    for tensor_np in states_nnet:
23        tensor = torch.tensor(tensor_np, device=device)
24        states_nnet_tensors.append(tensor)
25
26    return states_nnet_tensors
def get_device() -> Tuple[torch.device, List[int], bool]:
30def get_device() -> Tuple[torch.device, List[int], bool]:
31    device: torch.device = torch.device("cpu")
32    devices: List[int] = []
33    on_gpu: bool = False
34    if ('CUDA_VISIBLE_DEVICES' in os.environ) and torch.cuda.is_available():
35        device = torch.device("cuda:%i" % 0)
36        devices = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")]
37        on_gpu = True
38        torch.set_num_threads(1)
39    else:
40        torch.set_num_threads(8)
41
42    return device, devices, on_gpu
def load_nnet( model_file: str, nnet: torch.nn.modules.module.Module, device: Optional[torch.device] = None) -> torch.nn.modules.module.Module:
46def load_nnet(model_file: str, nnet: nn.Module, device: Optional[torch.device] = None) -> nn.Module:
47    # get state dict
48    if device is None:
49        state_dict = torch.load(model_file, weights_only=True)
50    else:
51        state_dict = torch.load(model_file, map_location=device, weights_only=False)
52
53    # remove module prefix
54    new_state_dict = OrderedDict()
55    for k, v in state_dict.items():
56        k = re.sub(r'^module\.', '', k)
57        new_state_dict[k] = v
58
59    # set state dict
60    nnet.load_state_dict(new_state_dict)
61
62    nnet.eval()
63
64    return nnet
def get_available_gpu_nums() -> List[int]:
67def get_available_gpu_nums() -> List[int]:
68    gpu_nums: List[int] = []
69    if ('CUDA_VISIBLE_DEVICES' in os.environ) and (len(os.environ['CUDA_VISIBLE_DEVICES']) > 0):
70        gpu_nums = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")]
71
72    return gpu_nums
def nnet_batched( nnet: torch.nn.modules.module.Module, inputs: List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]], batch_size: Optional[int], device: torch.device) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.float64]]]:
 75def nnet_batched(nnet: nn.Module, inputs: List[NDArray[Any]], batch_size: Optional[int],
 76                 device: torch.device) -> List[NDArray[np.float64]]:
 77    outputs_l_l: List[List[NDArray[np.float64]]] = []
 78
 79    num_states: int = inputs[0].shape[0]
 80
 81    batch_size_inst: int = num_states
 82    if batch_size is not None:
 83        batch_size_inst = batch_size
 84
 85    start_idx: int = 0
 86    num_outputs: Optional[int] = None
 87    while start_idx < num_states:
 88        # get batch
 89        end_idx: int = min(start_idx + batch_size_inst, num_states)
 90        inputs_batch = [x[start_idx:end_idx] for x in inputs]
 91
 92        # get nnet output
 93        inputs_batch_t = to_pytorch_input(inputs_batch, device)
 94
 95        outputs_batch_t_l: List[Tensor] = nnet(inputs_batch_t)
 96        outputs_batch_l: List[NDArray[np.float64]] = [outputs_batch_t.cpu().data.numpy()
 97                                                      for outputs_batch_t in outputs_batch_t_l]
 98        if num_outputs is None:
 99            num_outputs = len(outputs_batch_l)
100        else:
101            assert len(outputs_batch_l) == num_outputs, f"{len(outputs_batch_l)} != {num_outputs}"
102
103        for out_idx in range(len(outputs_batch_l)):
104            outputs_batch_l[out_idx] = outputs_batch_l[out_idx].astype(np.float64)
105        outputs_l_l.append(outputs_batch_l)
106
107        start_idx = end_idx
108
109    outputs_l: List[NDArray[np.float64]] = combine_l_l(outputs_l_l, "concat")
110    for out_idx in range(len(outputs_l)):
111        assert (outputs_l[out_idx].shape[0] == num_states)
112
113    return outputs_l
@dataclass
class NNetParInfo:
116@dataclass
117class NNetParInfo:
118    nnet_i_q: Queue
119    nnet_o_q: Queue
120    proc_id: int
NNetParInfo( nnet_i_q: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object>>, nnet_o_q: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object>>, proc_id: int)
nnet_i_q: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object at 0x1072559c0>>
nnet_o_q: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object at 0x1072559c0>>
proc_id: int
def nnet_in_out_shared_q( nnet: torch.nn.modules.module.Module, inputs_nnet_shm: List[deepxube.utils.data_utils.SharedNDArray], batch_size: Optional[int], device: torch.device, nnet_o_q: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object>>) -> None:
123def nnet_in_out_shared_q(nnet: nn.Module, inputs_nnet_shm: List[SharedNDArray], batch_size: Optional[int],
124                         device: torch.device, nnet_o_q: Queue) -> None:
125    # get outputs
126    inputs_nnet: List[NDArray] = []
127    for inputs_idx in range(len(inputs_nnet_shm)):
128        inputs_nnet.append(inputs_nnet_shm[inputs_idx].array)
129
130    outputs_l: List[NDArray[np.float64]] = nnet_batched(nnet, inputs_nnet, batch_size, device)
131
132    # send outputs
133    outputs_l_shm: List[SharedNDArray] = [np_to_shnd(outputs) for outputs in outputs_l]
134    nnet_o_q.put(outputs_l_shm)
135
136    for arr_shm in inputs_nnet_shm + outputs_l_shm:
137        arr_shm.close()
def nnet_fn_runner( nnet_i_q: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object>>, nnet_o_qs: List[<bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object>>], model_file: str, device: torch.device, on_gpu: bool, gpu_num: int, get_nnet: Callable[[], torch.nn.modules.module.Module], batch_size: Optional[int]) -> None:
141def nnet_fn_runner(nnet_i_q: Queue, nnet_o_qs: List[Queue], model_file: str, device: torch.device, on_gpu: bool,
142                   gpu_num: int, get_nnet: Callable[[], nn.Module], batch_size: Optional[int]) -> None:
143    if (gpu_num is not None) and on_gpu:
144        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)
145
146    torch.set_num_threads(1)
147    nnet: nn.Module = get_nnet()
148    nnet = load_nnet(model_file, nnet, device=device)
149    nnet.eval()
150    nnet.to(device)
151    # if on_gpu:
152    #    nnet = nn.DataParallel(nnet)
153
154    while True:
155        # get from input q
156        inputs_nnet_shm: Optional[List[SharedNDArray]]
157        proc_id, inputs_nnet_shm = nnet_i_q.get()
158        if proc_id is None:
159            break
160
161        # nnet in/out
162        nnet_in_out_shared_q(nnet, inputs_nnet_shm, batch_size, device, nnet_o_qs[proc_id])
def get_nnet_par_infos(num_procs: int) -> List[NNetParInfo]:
165def get_nnet_par_infos(num_procs: int) -> List[NNetParInfo]:
166    ctx = get_context("spawn")
167
168    nnet_i_q: Queue = ctx.Queue()
169    nnet_o_qs: List[Queue] = []
170    nnet_par_infos: List[NNetParInfo] = []
171    for proc_id in range(num_procs):
172        nnet_o_q: Queue = ctx.Queue(1)
173        nnet_o_qs.append(nnet_o_q)
174        nnet_par_infos.append(NNetParInfo(nnet_i_q, nnet_o_q, proc_id))
175
176    return nnet_par_infos
def start_nnet_fn_runners( get_nnet: Callable[[], torch.nn.modules.module.Module], nnet_par_infos: List[NNetParInfo], model_file: str, device: torch.device, on_gpu: bool, batch_size: Optional[int] = None) -> List[multiprocessing.process.BaseProcess]:
179def start_nnet_fn_runners(get_nnet: Callable[[], nn.Module], nnet_par_infos: List[NNetParInfo], model_file: str,
180                          device: torch.device, on_gpu: bool,
181                          batch_size: Optional[int] = None) -> List[BaseProcess]:
182    ctx = get_context("spawn")
183
184    nnet_i_q: Queue = nnet_par_infos[0].nnet_i_q
185    nnet_o_qs: List[Queue] = [nnet_par_info.nnet_o_q for nnet_par_info in nnet_par_infos]
186
187    # initialize heuristic procs
188    if ('CUDA_VISIBLE_DEVICES' in os.environ) and (len(os.environ['CUDA_VISIBLE_DEVICES']) > 0):
189        gpu_nums = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(",")]
190    else:
191        gpu_nums = [-1]
192
193    nnet_procs: List[BaseProcess] = []
194    for gpu_num in gpu_nums:
195        nnet_fn_procs = ctx.Process(target=nnet_fn_runner,
196                                    args=(nnet_i_q, nnet_o_qs, model_file, device, on_gpu, gpu_num, get_nnet,
197                                          batch_size))
198        nnet_fn_procs.daemon = True
199        nnet_fn_procs.start()
200        nnet_procs.append(nnet_fn_procs)
201
202    return nnet_procs
def stop_nnet_runners( nnet_fn_procs: List[multiprocessing.process.BaseProcess], nnet_par_infos: List[NNetParInfo]) -> None:
205def stop_nnet_runners(nnet_fn_procs: List[BaseProcess], nnet_par_infos: List[NNetParInfo]) -> None:
206    for _ in nnet_fn_procs:
207        nnet_par_infos[0].nnet_i_q.put((None, None))
208
209    for heur_fn_proc in nnet_fn_procs:
210        heur_fn_proc.join()
NNetCallable = typing.Callable[..., typing.Any]
class NNetPar(abc.ABC, typing.Generic[~NNetFn]):
217class NNetPar(ABC, Generic[NNetFn]):
218    @abstractmethod
219    def get_nnet_fn(self, nnet: nn.Module, batch_size: Optional[int], device: torch.device,
220                    update_num: Optional[int]) -> NNetFn:
221        pass
222
223    @abstractmethod
224    def get_nnet_par_fn(self, nnet_par_info: NNetParInfo, update_num: Optional[int]) -> NNetFn:
225        pass
226
227    @abstractmethod
228    def get_nnet(self) -> nn.Module:
229        pass

Helper class that provides a standard way to create an ABC using inheritance.

@abstractmethod
def get_nnet_fn( self, nnet: torch.nn.modules.module.Module, batch_size: Optional[int], device: torch.device, update_num: Optional[int]) -> ~NNetFn:
218    @abstractmethod
219    def get_nnet_fn(self, nnet: nn.Module, batch_size: Optional[int], device: torch.device,
220                    update_num: Optional[int]) -> NNetFn:
221        pass
@abstractmethod
def get_nnet_par_fn( self, nnet_par_info: NNetParInfo, update_num: Optional[int]) -> ~NNetFn:
223    @abstractmethod
224    def get_nnet_par_fn(self, nnet_par_info: NNetParInfo, update_num: Optional[int]) -> NNetFn:
225        pass
@abstractmethod
def get_nnet(self) -> torch.nn.modules.module.Module:
227    @abstractmethod
228    def get_nnet(self) -> nn.Module:
229        pass
def get_nnet_par_out( inputs_nnet: List[numpy.ndarray[tuple[int, ...], numpy.dtype[+_ScalarType_co]]], nnet_par_info: NNetParInfo) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[+_ScalarType_co]]]:
232def get_nnet_par_out(inputs_nnet: List[NDArray], nnet_par_info: NNetParInfo) -> List[NDArray]:
233    inputs_nnet_shm: List[SharedNDArray] = [np_to_shnd(inputs_nnet_i)
234                                            for input_idx, inputs_nnet_i in enumerate(inputs_nnet)]
235
236    nnet_par_info.nnet_i_q.put((nnet_par_info.proc_id, inputs_nnet_shm))
237
238    out_shm_l: List[SharedNDArray] = nnet_par_info.nnet_o_q.get()
239    out_l: List[NDArray] = [out_shm.array.copy() for out_shm in out_shm_l]
240
241    for arr_shm in inputs_nnet_shm + out_shm_l:
242        arr_shm.close()
243        arr_shm.unlink()
244
245    return out_l