2424from ... import oscar as mo
2525from ...core .entrypoints import init_extension_entrypoints
2626from ...lib .aio import get_isolation , stop_isolation
27- from ...resource import cpu_count , cuda_count
27+ from ...resource import cpu_count , cuda_count , mem_total , Resource
2828from ...services import NodeRole
2929from ...typing import ClusterType , ClientType
3030from ..utils import get_third_party_modules_from_config
@@ -51,6 +51,7 @@ async def new_cluster_in_isolation(
5151 address : str = "0.0.0.0" ,
5252 n_worker : int = 1 ,
5353 n_cpu : Union [int , str ] = "auto" ,
54+ mem_bytes : Union [int , str ] = "auto" ,
5455 cuda_devices : Union [List [int ], str ] = "auto" ,
5556 subprocess_start_method : str = None ,
5657 backend : str = None ,
@@ -65,6 +66,7 @@ async def new_cluster_in_isolation(
6566 address ,
6667 n_worker ,
6768 n_cpu ,
69+ mem_bytes ,
6870 cuda_devices ,
6971 subprocess_start_method ,
7072 config ,
@@ -79,6 +81,7 @@ async def new_cluster(
7981 address : str = "0.0.0.0" ,
8082 n_worker : int = 1 ,
8183 n_cpu : Union [int , str ] = "auto" ,
84+ mem_bytes : Union [int , str ] = "auto" ,
8285 cuda_devices : Union [List [int ], str ] = "auto" ,
8386 subprocess_start_method : str = None ,
8487 config : Union [str , Dict ] = None ,
@@ -91,6 +94,7 @@ async def new_cluster(
9194 address ,
9295 n_worker = n_worker ,
9396 n_cpu = n_cpu ,
97+ mem_bytes = mem_bytes ,
9498 cuda_devices = cuda_devices ,
9599 subprocess_start_method = subprocess_start_method ,
96100 config = config ,
@@ -116,6 +120,7 @@ def __init__(
116120 address : str = "0.0.0.0" ,
117121 n_worker : int = 1 ,
118122 n_cpu : Union [int , str ] = "auto" ,
123+ mem_bytes : Union [int , str ] = "auto" ,
119124 cuda_devices : Union [List [int ], List [List [int ]], str ] = "auto" ,
120125 subprocess_start_method : str = None ,
121126 config : Union [str , Dict ] = None ,
@@ -132,6 +137,7 @@ def __init__(
132137 self ._subprocess_start_method = subprocess_start_method
133138 self ._config = config
134139 self ._n_cpu = cpu_count () if n_cpu == "auto" else n_cpu
140+ self ._mem_bytes = mem_total () if mem_bytes == "auto" else mem_bytes
135141 self ._n_supervisor_process = n_supervisor_process
136142 if cuda_devices == "auto" :
137143 total = cuda_count ()
@@ -148,19 +154,22 @@ def __init__(
148154
149155 self ._n_worker = n_worker
150156 self ._web = web
151- self ._bands_to_slot = bands_to_slot = []
157+ self ._bands_to_resource = bands_to_resource = []
152158 worker_cpus = self ._n_cpu // n_worker
153159 if sum (len (devices ) for devices in devices_list ) == 0 :
154160 assert worker_cpus > 0 , (
155161 f"{ self ._n_cpu } cpus are not enough "
156162 f"for { n_worker } , try to decrease workers."
157163 )
164+ mem_bytes = self ._mem_bytes // n_worker
158165 for _ , devices in zip (range (n_worker ), devices_list ):
159- worker_band_to_slot = dict ()
160- worker_band_to_slot ["numa-0" ] = worker_cpus
166+ worker_band_to_resource = dict ()
167+ worker_band_to_resource ["numa-0" ] = Resource (
168+ num_cpus = worker_cpus , mem_bytes = mem_bytes
169+ )
161170 for i in devices : # pragma: no cover
162- worker_band_to_slot [f"gpu-{ i } " ] = 1
163- bands_to_slot .append (worker_band_to_slot )
171+ worker_band_to_resource [f"gpu-{ i } " ] = Resource ( num_gpus = 1 )
172+ bands_to_resource .append (worker_band_to_resource )
164173 self ._supervisor_pool = None
165174 self ._worker_pools = []
166175
@@ -211,10 +220,10 @@ async def _start_worker_pools(self):
211220 worker_modules = get_third_party_modules_from_config (
212221 self ._config , NodeRole .WORKER
213222 )
214- for band_to_slot in self ._bands_to_slot :
223+ for band_to_resource in self ._bands_to_resource :
215224 worker_pool = await create_worker_actor_pool (
216225 self ._address ,
217- band_to_slot ,
226+ band_to_resource ,
218227 modules = worker_modules ,
219228 subprocess_start_method = self ._subprocess_start_method ,
220229 metrics = self ._config .get ("metrics" , {}),
@@ -225,11 +234,13 @@ async def _start_service(self):
225234 self ._web = await start_supervisor (
226235 self .supervisor_address , config = self ._config , web = self ._web
227236 )
228- for worker_pool , band_to_slot in zip (self ._worker_pools , self ._bands_to_slot ):
237+ for worker_pool , band_to_resource in zip (
238+ self ._worker_pools , self ._bands_to_resource
239+ ):
229240 await start_worker (
230241 worker_pool .external_address ,
231242 self .supervisor_address ,
232- band_to_slot ,
243+ band_to_resource ,
233244 config = self ._config ,
234245 )
235246
0 commit comments