diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 1640f8a678c6..cf4fb39b9d47 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -115,9 +115,7 @@ def select_device(device="", batch_size=0, newline=True): device = str(device).strip().lower().replace("cuda:", "").replace("none", "") # to string, 'cuda:0' to '0' cpu = device == "cpu" mps = device == "mps" # Apple Metal Performance Shaders (MPS) - if cpu or mps: - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False - elif device: # non-cpu device requested + if device and not cpu and not mps: # non-cpu device requested os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available() assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", "")), ( f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"