Skip to content

Commit f68702f

Browse files
committed
Update IPEX libs
1 parent 386b733 commit f68702f

File tree

6 files changed

+337
-563
lines changed

6 files changed

+337
-563
lines changed

library/device_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
import gc
33

44
import torch
5+
try:
6+
# intel gpu support for pytorch older than 2.5
7+
# ipex is not needed after pytorch 2.5
8+
import intel_extension_for_pytorch as ipex # noqa
9+
except Exception:
10+
pass
11+
512

613
try:
714
HAS_CUDA = torch.cuda.is_available()
@@ -14,8 +21,6 @@
1421
HAS_MPS = False
1522

1623
try:
17-
import intel_extension_for_pytorch as ipex # noqa
18-
1924
HAS_XPU = torch.xpu.is_available()
2025
except Exception:
2126
HAS_XPU = False
@@ -69,7 +74,7 @@ def init_ipex():
6974
7075
This function should run right after importing torch and before doing anything else.
7176
72-
If IPEX is not available, this function does nothing.
77+
If xpu is not available, this function does nothing.
7378
"""
7479
try:
7580
if HAS_XPU:

library/ipex/__init__.py

Lines changed: 104 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import sys
33
import contextlib
44
import torch
5-
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
5+
try:
6+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
7+
legacy = True
8+
except Exception:
9+
legacy = False
610
from .hijacks import ipex_hijacks
711

812
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements
1216
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
1317
return True, "Skipping IPEX hijack"
1418
else:
19+
try: # force xpu device on torch compile and triton
20+
torch._inductor.utils.GPU_TYPES = ["xpu"]
21+
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
22+
from triton import backends as triton_backends # pylint: disable=import-error
23+
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
24+
except Exception:
25+
pass
1526
# Replace cuda with xpu:
1627
torch.cuda.current_device = torch.xpu.current_device
1728
torch.cuda.current_stream = torch.xpu.current_stream
@@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements
2637
torch.cuda.is_current_stream_capturing = lambda: False
2738
torch.cuda.set_device = torch.xpu.set_device
2839
torch.cuda.stream = torch.xpu.stream
29-
torch.cuda.synchronize = torch.xpu.synchronize
3040
torch.cuda.Event = torch.xpu.Event
3141
torch.cuda.Stream = torch.xpu.Stream
32-
torch.cuda.FloatTensor = torch.xpu.FloatTensor
3342
torch.Tensor.cuda = torch.Tensor.xpu
3443
torch.Tensor.is_cuda = torch.Tensor.is_xpu
3544
torch.nn.Module.cuda = torch.nn.Module.xpu
36-
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
37-
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
38-
torch.cuda._initialized = torch.xpu.lazy_init._initialized
39-
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
40-
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
41-
torch.cuda._tls = torch.xpu.lazy_init._tls
42-
torch.cuda.threading = torch.xpu.lazy_init.threading
43-
torch.cuda.traceback = torch.xpu.lazy_init.traceback
4445
torch.cuda.Optional = torch.xpu.Optional
4546
torch.cuda.__cached__ = torch.xpu.__cached__
4647
torch.cuda.__loader__ = torch.xpu.__loader__
47-
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
4848
torch.cuda.Tuple = torch.xpu.Tuple
4949
torch.cuda.streams = torch.xpu.streams
50-
torch.cuda._lazy_new = torch.xpu._lazy_new
51-
torch.cuda.FloatStorage = torch.xpu.FloatStorage
5250
torch.cuda.Any = torch.xpu.Any
5351
torch.cuda.__doc__ = torch.xpu.__doc__
5452
torch.cuda.default_generators = torch.xpu.default_generators
55-
torch.cuda.HalfTensor = torch.xpu.HalfTensor
5653
torch.cuda._get_device_index = torch.xpu._get_device_index
5754
torch.cuda.__path__ = torch.xpu.__path__
58-
torch.cuda.Device = torch.xpu.Device
59-
torch.cuda.IntTensor = torch.xpu.IntTensor
60-
torch.cuda.ByteStorage = torch.xpu.ByteStorage
6155
torch.cuda.set_stream = torch.xpu.set_stream
62-
torch.cuda.BoolStorage = torch.xpu.BoolStorage
63-
torch.cuda.os = torch.xpu.os
6456
torch.cuda.torch = torch.xpu.torch
65-
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
6657
torch.cuda.Union = torch.xpu.Union
67-
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
68-
torch.cuda.ShortTensor = torch.xpu.ShortTensor
69-
torch.cuda.LongTensor = torch.xpu.LongTensor
70-
torch.cuda.IntStorage = torch.xpu.IntStorage
71-
torch.cuda.LongStorage = torch.xpu.LongStorage
7258
torch.cuda.__annotations__ = torch.xpu.__annotations__
7359
torch.cuda.__package__ = torch.xpu.__package__
7460
torch.cuda.__builtins__ = torch.xpu.__builtins__
75-
torch.cuda.CharTensor = torch.xpu.CharTensor
7661
torch.cuda.List = torch.xpu.List
7762
torch.cuda._lazy_init = torch.xpu._lazy_init
78-
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
79-
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
80-
torch.cuda.ByteTensor = torch.xpu.ByteTensor
8163
torch.cuda.StreamContext = torch.xpu.StreamContext
82-
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
83-
torch.cuda.ShortStorage = torch.xpu.ShortStorage
8464
torch.cuda._lazy_call = torch.xpu._lazy_call
85-
torch.cuda.HalfStorage = torch.xpu.HalfStorage
8665
torch.cuda.random = torch.xpu.random
8766
torch.cuda._device = torch.xpu._device
88-
torch.cuda.classproperty = torch.xpu.classproperty
8967
torch.cuda.__name__ = torch.xpu.__name__
9068
torch.cuda._device_t = torch.xpu._device_t
91-
torch.cuda.warnings = torch.xpu.warnings
9269
torch.cuda.__spec__ = torch.xpu.__spec__
93-
torch.cuda.BoolTensor = torch.xpu.BoolTensor
94-
torch.cuda.CharStorage = torch.xpu.CharStorage
9570
torch.cuda.__file__ = torch.xpu.__file__
96-
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
9771
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
9872

73+
if legacy:
74+
torch.cuda.os = torch.xpu.os
75+
torch.cuda.Device = torch.xpu.Device
76+
torch.cuda.warnings = torch.xpu.warnings
77+
torch.cuda.classproperty = torch.xpu.classproperty
78+
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
79+
if float(ipex.__version__[:3]) < 2.3:
80+
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
81+
torch.cuda._initialized = torch.xpu.lazy_init._initialized
82+
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
83+
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
84+
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
85+
torch.cuda._tls = torch.xpu.lazy_init._tls
86+
torch.cuda.threading = torch.xpu.lazy_init.threading
87+
torch.cuda.traceback = torch.xpu.lazy_init.traceback
88+
torch.cuda._lazy_new = torch.xpu._lazy_new
89+
90+
torch.cuda.FloatTensor = torch.xpu.FloatTensor
91+
torch.cuda.FloatStorage = torch.xpu.FloatStorage
92+
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
93+
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
94+
torch.cuda.HalfTensor = torch.xpu.HalfTensor
95+
torch.cuda.HalfStorage = torch.xpu.HalfStorage
96+
torch.cuda.ByteTensor = torch.xpu.ByteTensor
97+
torch.cuda.ByteStorage = torch.xpu.ByteStorage
98+
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
99+
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
100+
torch.cuda.ShortTensor = torch.xpu.ShortTensor
101+
torch.cuda.ShortStorage = torch.xpu.ShortStorage
102+
torch.cuda.LongTensor = torch.xpu.LongTensor
103+
torch.cuda.LongStorage = torch.xpu.LongStorage
104+
torch.cuda.IntTensor = torch.xpu.IntTensor
105+
torch.cuda.IntStorage = torch.xpu.IntStorage
106+
torch.cuda.CharTensor = torch.xpu.CharTensor
107+
torch.cuda.CharStorage = torch.xpu.CharStorage
108+
torch.cuda.BoolTensor = torch.xpu.BoolTensor
109+
torch.cuda.BoolStorage = torch.xpu.BoolStorage
110+
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
111+
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
112+
113+
if not legacy or float(ipex.__version__[:3]) >= 2.3:
114+
torch.cuda._initialization_lock = torch.xpu._initialization_lock
115+
torch.cuda._initialized = torch.xpu._initialized
116+
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
117+
torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
118+
torch.cuda._queued_calls = torch.xpu._queued_calls
119+
torch.cuda._tls = torch.xpu._tls
120+
torch.cuda.threading = torch.xpu.threading
121+
torch.cuda.traceback = torch.xpu.traceback
122+
99123
# Memory:
100-
torch.cuda.memory = torch.xpu.memory
101124
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
102125
torch.xpu.empty_cache = lambda: None
103126
torch.cuda.empty_cache = torch.xpu.empty_cache
127+
128+
if legacy:
129+
torch.cuda.memory_summary = torch.xpu.memory_summary
130+
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
131+
torch.cuda.memory = torch.xpu.memory
104132
torch.cuda.memory_stats = torch.xpu.memory_stats
105-
torch.cuda.memory_summary = torch.xpu.memory_summary
106-
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
107133
torch.cuda.memory_allocated = torch.xpu.memory_allocated
108134
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
109135
torch.cuda.memory_reserved = torch.xpu.memory_reserved
@@ -128,52 +154,64 @@ def ipex_init(): # pylint: disable=too-many-statements
128154
torch.cuda.initial_seed = torch.xpu.initial_seed
129155

130156
# AMP:
131-
torch.cuda.amp = torch.xpu.amp
132-
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
133-
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
157+
if legacy:
158+
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
159+
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
160+
torch.cuda.amp = torch.xpu.amp
161+
if float(ipex.__version__[:3]) < 2.3:
162+
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
163+
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
134164

135-
if not hasattr(torch.cuda.amp, "common"):
136-
torch.cuda.amp.common = contextlib.nullcontext()
137-
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
165+
if not hasattr(torch.cuda.amp, "common"):
166+
torch.cuda.amp.common = contextlib.nullcontext()
167+
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
138168

139-
try:
140-
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
141-
except Exception: # pylint: disable=broad-exception-caught
142169
try:
143-
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
144-
gradscaler_init()
145170
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
146171
except Exception: # pylint: disable=broad-exception-caught
147-
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
172+
try:
173+
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
174+
gradscaler_init()
175+
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
176+
except Exception: # pylint: disable=broad-exception-caught
177+
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
148178

149179
# C
150-
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
151-
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
152-
ipex._C._DeviceProperties.major = 2024
153-
ipex._C._DeviceProperties.minor = 0
180+
if legacy and float(ipex.__version__[:3]) < 2.3:
181+
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
182+
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
183+
ipex._C._DeviceProperties.major = 12
184+
ipex._C._DeviceProperties.minor = 1
185+
else:
186+
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
187+
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
188+
torch._C._XpuDeviceProperties.major = 12
189+
torch._C._XpuDeviceProperties.minor = 1
154190

155191
# Fix functions with ipex:
156-
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
192+
# torch.xpu.mem_get_info always returns the total memory as free memory
193+
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
194+
torch.cuda.mem_get_info = torch.xpu.mem_get_info
157195
torch._utils._get_available_device_type = lambda: "xpu"
158196
torch.has_cuda = True
159197
torch.cuda.has_half = True
160198
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
161199
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
162200
torch.backends.cuda.is_built = lambda *args, **kwargs: True
163201
torch.version.cuda = "12.1"
164-
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
202+
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
203+
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
165204
torch.cuda.get_device_properties.major = 12
166205
torch.cuda.get_device_properties.minor = 1
167206
torch.cuda.ipc_collect = lambda *args, **kwargs: None
168207
torch.cuda.utilization = lambda *args, **kwargs: 0
169208

170-
ipex_hijacks()
171-
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
172-
try:
173-
from .diffusers import ipex_diffusers
174-
ipex_diffusers()
175-
except Exception: # pylint: disable=broad-exception-caught
176-
pass
209+
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
210+
try:
211+
from .diffusers import ipex_diffusers
212+
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
213+
except Exception: # pylint: disable=broad-exception-caught
214+
pass
177215
torch.cuda.is_xpu_hijacked = True
178216
except Exception as e:
179217
return False, e

0 commit comments

Comments
 (0)