75 | 75 | is_fp8_available, |
76 | 76 | is_ipex_available, |
77 | 77 | is_megatron_lm_available, |
| 78 | + is_npu_available, |
78 | 79 | is_safetensors_available, |
79 | 80 | is_torch_version, |
80 | 81 | is_tpu_available, |
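The new `is_npu_available` import is the switch the rest of the diff keys off. As context, here is a minimal sketch of what such an availability probe typically looks like; this is an illustrative shape only, not Accelerate's actual implementation, and it assumes the Ascend `torch_npu` plugin registers the `torch.npu` namespace when it is installed.

```python
# Illustrative sketch only -- not the library's real is_npu_available().
# Assumes the Ascend torch_npu plugin, when installed, exposes torch.npu.
import importlib.util

import torch


def npu_is_usable() -> bool:  # hypothetical helper name
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu  # noqa: F401  # importing it registers the torch.npu backend
    return hasattr(torch, "npu") and torch.npu.is_available()
```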
@@ -413,13 +414,15 @@ def __init__( |
413 | 414 | and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM) |
414 | 415 | ): |
415 | 416 | self.native_amp = True |
416 | | - if self.device.type not in ("cuda", "mps"): |
| 417 | + if self.device.type not in ("cuda", "mps", "npu"): |
417 | 418 | raise ValueError(err.format(mode="fp16", requirement="a GPU")) |
418 | 419 | kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {} |
419 | 420 | if self.distributed_type == DistributedType.FSDP: |
420 | 421 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler |
421 | 422 |
422 | 423 | self.scaler = ShardedGradScaler(**kwargs) |
| 424 | + elif is_npu_available(): |
| 425 | + self.scaler = torch.npu.amp.GradScaler(**kwargs) |
423 | 426 | else: |
424 | 427 | self.scaler = torch.cuda.amp.GradScaler(**kwargs) |
425 | 428 |
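The hunk above picks the fp16 gradient scaler to match the backend: `ShardedGradScaler` under FSDP, the Ascend scaler when an NPU is present, otherwise the CUDA one. A condensed sketch of that selection, written as a standalone helper with the FSDP branch omitted (it assumes `is_npu_available` is exported from `accelerate.utils`, as the import block above suggests):

```python
# Condensed sketch of the scaler selection above (FSDP branch omitted).
# torch.npu.amp.GradScaler is provided by the Ascend torch_npu plugin.
import torch

from accelerate.utils import is_npu_available


def make_fp16_scaler(**kwargs):
    if is_npu_available():
        return torch.npu.amp.GradScaler(**kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)
```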
@@ -965,7 +968,7 @@ def join_uneven_inputs(self, joinables, even_batches=None): |
965 | 968 | ... optimizer.zero_grad() |
966 | 969 | ``` |
967 | 970 | """ |
968 | | - if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU): |
| 971 | + if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_XPU): |
969 | 972 | dl_even_batches_values = [] |
970 | 973 |
971 | 974 | if even_batches is not None: |
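`join_uneven_inputs` now also accepts `MULTI_NPU`, since it ultimately relies on PyTorch's generic `Join` context manager around the DDP-wrapped model rather than anything CUDA-specific. A minimal, hedged illustration of that underlying pattern (it assumes `ddp_model` is already wrapped in `DistributedDataParallel` inside an initialized process group; this is not Accelerate's own code):

```python
# Underlying pattern (sketch): Join lets ranks that run out of batches shadow
# the collectives of ranks that are still iterating, so uneven inputs do not hang.
from torch.distributed.algorithms import Join


def train_uneven(ddp_model, optimizer, dataloader, loss_fn):
    with Join([ddp_model]):
        for inputs, targets in dataloader:
            loss = loss_fn(ddp_model(inputs), targets)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
```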
@@ -1292,7 +1295,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e |
1292 | 1295 | model._original_forward = model.forward |
1293 | 1296 | model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward |
1294 | 1297 | if self.mixed_precision == "fp16": |
1295 | | - new_forward = torch.cuda.amp.autocast(dtype=torch.float16)(model_forward_func) |
| 1298 | + if is_npu_available(): |
| 1299 | + new_forward = torch.npu.amp.autocast(dtype=torch.float16)(model_forward_func) |
| 1300 | + else: |
| 1301 | + new_forward = torch.cuda.amp.autocast(dtype=torch.float16)(model_forward_func) |
1296 | 1302 | elif self.mixed_precision == "bf16" and self.distributed_type != DistributedType.TPU: |
1297 | 1303 | new_forward = torch.autocast(device_type=self.device.type, dtype=torch.bfloat16)(model_forward_func) |
1298 | 1304 |
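The same device split applies to autocast: the fp16 branch now chooses between the CUDA and Ascend autocast decorators before wrapping the model's forward. A minimal sketch of that choice (illustrative only; the library's real path also re-binds the method and converts outputs back to fp32):

```python
# Minimal sketch: pick the fp16 autocast decorator to match the backend.
import torch

from accelerate.utils import is_npu_available


def fp16_forward(forward_func):
    if is_npu_available():
        # torch.npu.amp.autocast is provided by the Ascend torch_npu plugin.
        return torch.npu.amp.autocast(dtype=torch.float16)(forward_func)
    return torch.cuda.amp.autocast(dtype=torch.float16)(forward_func)


# usage (sketch): model.forward = fp16_forward(model.forward)
```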
@@ -1324,7 +1330,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e |
1324 | 1330 | ) |
1325 | 1331 | model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward) |
1326 | 1332 | if not evaluation_mode: |
1327 | | - if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU): |
| 1333 | + if self.distributed_type in ( |
| 1334 | + DistributedType.MULTI_GPU, |
| 1335 | + DistributedType.MULTI_NPU, |
| 1336 | + DistributedType.MULTI_XPU, |
| 1337 | + ): |
1328 | 1338 | if any(p.requires_grad for p in model.parameters()): |
1329 | 1339 | kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} |
1330 | 1340 | model = torch.nn.parallel.DistributedDataParallel( |
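With `MULTI_NPU` added to this branch, NPU processes get the same `DistributedDataParallel` wrapping as GPU and XPU ones, gated on the model actually having trainable parameters. A hedged sketch of that wrapping for a one-device-per-process setup (`local_rank` and the extra kwargs are placeholders for state the Accelerator tracks internally; the exact arguments are not shown in this hunk):

```python
# Hedged sketch of per-process DDP wrapping; local_rank stands in for the
# local process index that Accelerate manages itself.
import torch


def wrap_ddp(model: torch.nn.Module, local_rank: int, **ddp_kwargs):
    if any(p.requires_grad for p in model.parameters()):
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank, **ddp_kwargs
        )
    return model
```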
@@ -2686,7 +2696,10 @@ def load_state(self, input_dir: str, **load_model_func_kwargs): |
2686 | 2696 |
2687 | 2697 | map_location = load_model_func_kwargs.pop("map_location", None) |
2688 | 2698 | if map_location is None: |
2689 | | - if self.num_processes > 1 and self.distributed_type == DistributedType.MULTI_GPU: |
| 2699 | + if self.num_processes > 1 and self.distributed_type in ( |
| 2700 | + DistributedType.MULTI_GPU, |
| 2701 | + DistributedType.MULTI_NPU, |
| 2702 | + ): |
2690 | 2703 | map_location = "on_device" |
2691 | 2704 | else: |
2692 | 2705 | map_location = "cpu" |
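Finally, `load_state` extends its default `map_location` to NPU runs: multi-process GPU/NPU jobs load checkpoint weights directly onto the local device, while everything else stages through CPU. A small sketch of that default (assuming `DistributedType.MULTI_NPU` is defined by this PR's utils changes):

```python
# Sketch of the default map_location chosen above.
from accelerate.utils import DistributedType


def default_map_location(num_processes: int, distributed_type) -> str:
    if num_processes > 1 and distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
    ):
        return "on_device"
    return "cpu"
```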