
Commit c33adec

ji-huazhong and jihuazhong authored
Add Ascend NPU accelerator support (#1676)
* add Ascend NPU accelerator support
* fix code styles
* enable accelerate test on npu
* fix typo & code styles

Co-authored-by: jihuazhong <[email protected]>
1 parent 518c206 commit c33adec

File tree: 15 files changed, +165 −36 lines changed


src/accelerate/accelerator.py

Lines changed: 18 additions & 5 deletions
@@ -75,6 +75,7 @@
     is_fp8_available,
     is_ipex_available,
     is_megatron_lm_available,
+    is_npu_available,
     is_safetensors_available,
     is_torch_version,
     is_tpu_available,

@@ -413,13 +414,15 @@ def __init__(
             and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
         ):
             self.native_amp = True
-            if self.device.type not in ("cuda", "mps"):
+            if self.device.type not in ("cuda", "mps", "npu"):
                 raise ValueError(err.format(mode="fp16", requirement="a GPU"))
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
             if self.distributed_type == DistributedType.FSDP:
                 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

                 self.scaler = ShardedGradScaler(**kwargs)
+            elif is_npu_available():
+                self.scaler = torch.npu.amp.GradScaler(**kwargs)
             else:
                 self.scaler = torch.cuda.amp.GradScaler(**kwargs)

@@ -965,7 +968,7 @@ def join_uneven_inputs(self, joinables, even_batches=None):
        ...         optimizer.zero_grad()
        ```
        """
-        if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU):
+        if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_XPU):
             dl_even_batches_values = []

             if even_batches is not None:

@@ -1292,7 +1295,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 model._original_forward = model.forward
                 model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
                 if self.mixed_precision == "fp16":
-                    new_forward = torch.cuda.amp.autocast(dtype=torch.float16)(model_forward_func)
+                    if is_npu_available():
+                        new_forward = torch.npu.amp.autocast(dtype=torch.float16)(model_forward_func)
+                    else:
+                        new_forward = torch.cuda.amp.autocast(dtype=torch.float16)(model_forward_func)
                 elif self.mixed_precision == "bf16" and self.distributed_type != DistributedType.TPU:
                     new_forward = torch.autocast(device_type=self.device.type, dtype=torch.bfloat16)(model_forward_func)

@@ -1324,7 +1330,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             )
             model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
         if not evaluation_mode:
-            if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU):
+            if self.distributed_type in (
+                DistributedType.MULTI_GPU,
+                DistributedType.MULTI_NPU,
+                DistributedType.MULTI_XPU,
+            ):
                 if any(p.requires_grad for p in model.parameters()):
                     kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
                     model = torch.nn.parallel.DistributedDataParallel(

@@ -2686,7 +2696,10 @@ def load_state(self, input_dir: str, **load_model_func_kwargs):

         map_location = load_model_func_kwargs.pop("map_location", None)
         if map_location is None:
-            if self.num_processes > 1 and self.distributed_type == DistributedType.MULTI_GPU:
+            if self.num_processes > 1 and self.distributed_type in (
+                DistributedType.MULTI_GPU,
+                DistributedType.MULTI_NPU,
+            ):
                 map_location = "on_device"
             else:
                 map_location = "cpu"
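The scaler/autocast dispatch added above reduces to the following standalone sketch. This is illustrative only; it assumes the torch_npu package is installed, which is what registers the torch.npu namespace that is_npu_available() detects:

import torch

def make_grad_scaler(npu_available: bool, **kwargs):
    # Mirrors the new branch in Accelerator.__init__: prefer the Ascend
    # scaler when an NPU is present, otherwise fall back to the CUDA one.
    if npu_available:
        return torch.npu.amp.GradScaler(**kwargs)  # provided by torch_npu
    return torch.cuda.amp.GradScaler(**kwargs)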

src/accelerate/commands/config/cluster.py

Lines changed: 20 additions & 4 deletions
@@ -47,7 +47,7 @@
 def get_cluster_input():
     distributed_type = _ask_options(
         "Which type of machine are you using?",
-        ["No distributed training", "multi-CPU", "multi-XPU", "multi-GPU", "TPU"],
+        ["No distributed training", "multi-CPU", "multi-XPU", "multi-GPU", "multi-NPU", "TPU"],
         _convert_distributed_mode,
     )

@@ -60,7 +60,12 @@ def get_cluster_input():
         rdzv_backend = "static"
         same_network = True

-    if distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_CPU]:
+    if distributed_type in [
+        DistributedType.MULTI_GPU,
+        DistributedType.MULTI_NPU,
+        DistributedType.MULTI_XPU,
+        DistributedType.MULTI_CPU,
+    ]:
         num_machines = _ask_field(
             "How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
             int,

@@ -110,7 +115,11 @@ def get_cluster_input():
         default=False,
         error_message="Please enter yes or no.",
     )
-    if not use_cpu and is_xpu_available() and distributed_type not in [DistributedType.MULTI_GPU, DistributedType.TPU]:
+    if (
+        not use_cpu
+        and is_xpu_available()
+        and distributed_type not in [DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.TPU]
+    ):
         ipex_config["use_xpu"] = _ask_field(
             "Do you want to use XPU plugin to speed up training on XPU? [yes/NO]:",
             _convert_yes_no_to_bool,

@@ -444,6 +453,7 @@ def get_cluster_input():
         DistributedType.MULTI_CPU,
         DistributedType.MULTI_XPU,
         DistributedType.MULTI_GPU,
+        DistributedType.MULTI_NPU,
         DistributedType.TPU,
     ]:
         machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")

@@ -468,7 +478,13 @@ def get_cluster_input():
         num_processes = 1

     if (
-        distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.NO]
+        distributed_type
+        in [
+            DistributedType.MULTI_GPU,
+            DistributedType.MULTI_NPU,
+            DistributedType.MULTI_XPU,
+            DistributedType.NO,
+        ]
         and not use_cpu
         and not use_mps
     ):
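With the new entry, the machine-type question offers six choices. An illustrative transcript (the CLI actually renders this as an interactive menu, so the exact presentation differs):

Which type of machine are you using?
  [0] No distributed training
  [1] multi-CPU
  [2] multi-XPU
  [3] multi-GPU
  [4] multi-NPU
  [5] TPU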

src/accelerate/commands/config/config_utils.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def _convert_compute_environment(value):

 def _convert_distributed_mode(value):
     value = int(value)
-    return DistributedType(["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "TPU"][value])
+    return DistributedType(["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "MULTI_NPU", "TPU"][value])


 def _convert_dynamo_backend(value):
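Note that this lookup is purely positional: its order must match the option order in cluster.py above, or a selection maps to the wrong mode. A quick sanity check (hypothetical, not part of the commit):

# Hypothetical check: both lists must stay aligned after adding "multi-NPU".
choices = ["No distributed training", "multi-CPU", "multi-XPU", "multi-GPU", "multi-NPU", "TPU"]
modes = ["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "MULTI_NPU", "TPU"]
assert len(choices) == len(modes)
assert modes[choices.index("multi-NPU")] == "MULTI_NPU"  # option 4 -> MULTI_NPU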

src/accelerate/commands/config/default.py

Lines changed: 9 additions & 1 deletion
@@ -18,7 +18,7 @@

 import torch

-from ...utils import is_xpu_available
+from ...utils import is_npu_available, is_xpu_available
 from .config_args import ClusterConfig, default_json_config_file
 from .config_utils import SubcommandHelpFormatter

@@ -73,6 +73,14 @@ def write_basic_config(mixed_precision="no", save_location: str = default_json_c
             config["distributed_type"] = "MULTI_XPU"
         else:
             config["distributed_type"] = "NO"
+    elif is_npu_available():
+        num_npus = torch.npu.device_count()
+        config["num_processes"] = num_npus
+        config["use_cpu"] = False
+        if num_npus > 1:
+            config["distributed_type"] = "MULTI_NPU"
+        else:
+            config["distributed_type"] = "NO"
     else:
         num_xpus = 0
         config["use_cpu"] = True
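As a usage sketch, generating a default config on an Ascend host now picks MULTI_NPU automatically (the eight-device count below is an assumption for illustration):

from accelerate.utils import write_basic_config

write_basic_config()
# On a host where torch.npu.device_count() == 8, the saved config contains
# roughly: distributed_type: MULTI_NPU, num_processes: 8, use_cpu: false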

src/accelerate/commands/env.py

Lines changed: 3 additions & 1 deletion
@@ -25,7 +25,7 @@
 from accelerate import __version__ as version
 from accelerate.commands.config import default_config_file, load_config_from_file

-from ..utils import is_xpu_available
+from ..utils import is_npu_available, is_xpu_available


 def env_command_parser(subparsers=None):

@@ -47,6 +47,7 @@ def env_command(args):
     pt_version = torch.__version__
     pt_cuda_available = torch.cuda.is_available()
     pt_xpu_available = is_xpu_available()
+    pt_npu_available = is_npu_available()

     accelerate_config = "Not found"
     # Get the default from the config file.

@@ -60,6 +61,7 @@ def env_command(args):
         "Numpy version": np.__version__,
         "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
         "PyTorch XPU available": str(pt_xpu_available),
+        "PyTorch NPU available": str(pt_npu_available),
         "System RAM": f"{psutil.virtual_memory().total / 1024 ** 3:.2f} GB",
     }
     if pt_cuda_available:
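The output of `accelerate env` therefore gains one line; an illustrative excerpt (values depend on the machine):

- PyTorch XPU available: False
- PyTorch NPU available: True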

src/accelerate/commands/launch.py

Lines changed: 10 additions & 2 deletions
@@ -36,6 +36,7 @@
     _filter_args,
     is_bf16_available,
     is_deepspeed_available,
+    is_npu_available,
     is_rich_available,
     is_sagemaker_available,
     is_torch_version,

@@ -828,7 +829,10 @@ def _validate_launch_command(args):
         ):
             args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
             args.multi_gpu = (
-                True if defaults.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU) else False
+                True
+                if defaults.distributed_type
+                in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_XPU)
+                else False
             )
             args.tpu = defaults.distributed_type == DistributedType.TPU
             args.use_fsdp = defaults.distributed_type == DistributedType.FSDP

@@ -896,11 +900,15 @@ def _validate_launch_command(args):
         if args.num_processes is None:
             if args.use_xpu and is_xpu_available():
                 args.num_processes = torch.xpu.device_count()
+            elif is_npu_available():
+                args.num_processes = torch.npu.device_count()
             else:
                 args.num_processes = torch.cuda.device_count()
             warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
         if not args.multi_gpu and (
-            (args.use_xpu and is_xpu_available() and torch.xpu.device_count() > 1) or (torch.cuda.device_count() > 1)
+            (args.use_xpu and is_xpu_available() and torch.xpu.device_count() > 1)
+            or (is_npu_available() and torch.npu.device_count() > 1)
+            or (torch.cuda.device_count() > 1)
         ):
             warned.append(
                 "\t\tMore than one GPU was found, enabling multi-GPU training.\n"

src/accelerate/state.py

Lines changed: 18 additions & 0 deletions
@@ -35,6 +35,7 @@
     is_fp8_available,
     is_ipex_available,
     is_mps_available,
+    is_npu_available,
     is_tpu_available,
     is_xpu_available,
     parse_choice_from_env,

@@ -195,6 +196,19 @@ def __init__(self, cpu: bool = False, **kwargs):
             if self.device is None:
                 self.device = torch.device("cuda", self.local_process_index)
             torch.cuda.set_device(self.device)
+        elif is_npu_available() and not cpu and int(os.environ.get("LOCAL_RANK", -1)) != -1:
+            self.distributed_type = DistributedType.MULTI_NPU
+            if not torch.distributed.is_initialized():
+                # Backend is not set by the user, we set it here
+                kwargs.pop("backend", None)
+                self.backend = "hccl"
+                torch.distributed.init_process_group(backend=self.backend, **kwargs)
+            self.num_processes = torch.distributed.get_world_size()
+            self.process_index = torch.distributed.get_rank()
+            self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
+            if self.device is None:
+                self.device = torch.device("npu", self.local_process_index)
+            torch.npu.set_device(self.device)
         elif get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1:
             if not cpu and is_xpu_available():
                 self.distributed_type = DistributedType.MULTI_XPU

@@ -343,6 +357,7 @@ def wait_for_everyone(self):
         """
         if self.distributed_type in (
             DistributedType.MULTI_GPU,
+            DistributedType.MULTI_NPU,
             DistributedType.MULTI_XPU,
             DistributedType.MULTI_CPU,
             DistributedType.DEEPSPEED,

@@ -649,6 +664,7 @@ def default_device(self) -> torch.device:
         Returns the default device which is:
         - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
         - CUDA if `torch.cuda.is_available()`
+        - NPU if `is_npu_available()`
         - CPU otherwise
         """
         if is_mps_available():

@@ -658,6 +674,8 @@ def default_device(self) -> torch.device:
             return torch.device("cuda")
         elif is_xpu_available():
             return torch.device("xpu:0")
+        elif is_npu_available():
+            return torch.device("npu")
         else:
             return torch.device("cpu")
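The MULTI_NPU branch is standard torchrun-style initialization, just with the HCCL backend and the `npu` device type. A minimal standalone sketch, assuming torch_npu is installed and the launcher has set RANK/WORLD_SIZE/LOCAL_RANK:

import os

import torch
import torch_npu  # noqa: F401  -- registers the "npu" device type and the HCCL backend

torch.distributed.init_process_group(backend="hccl")  # world size and rank read from env
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("npu", local_rank)
torch.npu.set_device(device)  # bind this process to its local NPU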

src/accelerate/test_utils/scripts/test_script.py

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@
     gather,
     is_bf16_available,
     is_ipex_available,
+    is_npu_available,
     is_xpu_available,
     set_seed,
     synchronize_rng_states,

@@ -358,7 +359,7 @@ def training_check():

     accelerator.print("Training yielded the same results on one CPU or distributes setup with batch split.")

-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or is_npu_available():
         # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
         print("FP16 training check.")
         AcceleratorState._reset_state()

src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@
     is_megatron_lm_available,
     is_mlflow_available,
     is_mps_available,
+    is_npu_available,
     is_rich_available,
     is_safetensors_available,
     is_sagemaker_available,

src/accelerate/utils/dataclasses.py

Lines changed: 3 additions & 0 deletions
@@ -182,6 +182,7 @@ class DistributedType(str, enum.Enum):
     - **NO** -- Not a distributed environment, just a single process.
     - **MULTI_CPU** -- Distributed on multiple CPU nodes.
     - **MULTI_GPU** -- Distributed on multiple GPUs.
+    - **MULTI_NPU** -- Distributed on multiple NPUs.
     - **MULTI_XPU** -- Distributed on multiple XPUs.
     - **DEEPSPEED** -- Using DeepSpeed.
     - **TPU** -- Distributed on TPUs.

@@ -191,6 +192,7 @@ class DistributedType(str, enum.Enum):
     NO = "NO"
     MULTI_CPU = "MULTI_CPU"
     MULTI_GPU = "MULTI_GPU"
+    MULTI_NPU = "MULTI_NPU"
     MULTI_XPU = "MULTI_XPU"
     DEEPSPEED = "DEEPSPEED"
     FSDP = "FSDP"

@@ -335,6 +337,7 @@ class PrecisionType(BaseEnum):
 class RNGType(BaseEnum):
     TORCH = "torch"
     CUDA = "cuda"
+    NPU = "npu"
     XLA = "xla"
     XPU = "xpu"
     GENERATOR = "generator"
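Downstream code can branch on the new enum member in the usual way; an illustrative check:

from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()
if accelerator.distributed_type == DistributedType.MULTI_NPU:
    print("running DDP over HCCL on Ascend NPUs")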
