22import sys
33import contextlib
44import torch
5- import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
5+ try :
6+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
7+ legacy = True
8+ except Exception :
9+ legacy = False
610from .hijacks import ipex_hijacks
711
812# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements
1216 if hasattr (torch , "cuda" ) and hasattr (torch .cuda , "is_xpu_hijacked" ) and torch .cuda .is_xpu_hijacked :
1317 return True , "Skipping IPEX hijack"
1418 else :
19+ try : # force xpu device on torch compile and triton
20+ torch ._inductor .utils .GPU_TYPES = ["xpu" ]
21+ torch ._inductor .utils .get_gpu_type = lambda * args , ** kwargs : "xpu"
22+ from triton import backends as triton_backends # pylint: disable=import-error
23+ triton_backends .backends ["nvidia" ].driver .is_active = lambda * args , ** kwargs : False
24+ except Exception :
25+ pass
1526 # Replace cuda with xpu:
1627 torch .cuda .current_device = torch .xpu .current_device
1728 torch .cuda .current_stream = torch .xpu .current_stream
@@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements
2637 torch .cuda .is_current_stream_capturing = lambda : False
2738 torch .cuda .set_device = torch .xpu .set_device
2839 torch .cuda .stream = torch .xpu .stream
29- torch .cuda .synchronize = torch .xpu .synchronize
3040 torch .cuda .Event = torch .xpu .Event
3141 torch .cuda .Stream = torch .xpu .Stream
32- torch .cuda .FloatTensor = torch .xpu .FloatTensor
3342 torch .Tensor .cuda = torch .Tensor .xpu
3443 torch .Tensor .is_cuda = torch .Tensor .is_xpu
3544 torch .nn .Module .cuda = torch .nn .Module .xpu
36- torch .UntypedStorage .cuda = torch .UntypedStorage .xpu
37- torch .cuda ._initialization_lock = torch .xpu .lazy_init ._initialization_lock
38- torch .cuda ._initialized = torch .xpu .lazy_init ._initialized
39- torch .cuda ._lazy_seed_tracker = torch .xpu .lazy_init ._lazy_seed_tracker
40- torch .cuda ._queued_calls = torch .xpu .lazy_init ._queued_calls
41- torch .cuda ._tls = torch .xpu .lazy_init ._tls
42- torch .cuda .threading = torch .xpu .lazy_init .threading
43- torch .cuda .traceback = torch .xpu .lazy_init .traceback
4445 torch .cuda .Optional = torch .xpu .Optional
4546 torch .cuda .__cached__ = torch .xpu .__cached__
4647 torch .cuda .__loader__ = torch .xpu .__loader__
47- torch .cuda .ComplexFloatStorage = torch .xpu .ComplexFloatStorage
4848 torch .cuda .Tuple = torch .xpu .Tuple
4949 torch .cuda .streams = torch .xpu .streams
50- torch .cuda ._lazy_new = torch .xpu ._lazy_new
51- torch .cuda .FloatStorage = torch .xpu .FloatStorage
5250 torch .cuda .Any = torch .xpu .Any
5351 torch .cuda .__doc__ = torch .xpu .__doc__
5452 torch .cuda .default_generators = torch .xpu .default_generators
55- torch .cuda .HalfTensor = torch .xpu .HalfTensor
5653 torch .cuda ._get_device_index = torch .xpu ._get_device_index
5754 torch .cuda .__path__ = torch .xpu .__path__
58- torch .cuda .Device = torch .xpu .Device
59- torch .cuda .IntTensor = torch .xpu .IntTensor
60- torch .cuda .ByteStorage = torch .xpu .ByteStorage
6155 torch .cuda .set_stream = torch .xpu .set_stream
62- torch .cuda .BoolStorage = torch .xpu .BoolStorage
63- torch .cuda .os = torch .xpu .os
6456 torch .cuda .torch = torch .xpu .torch
65- torch .cuda .BFloat16Storage = torch .xpu .BFloat16Storage
6657 torch .cuda .Union = torch .xpu .Union
67- torch .cuda .DoubleTensor = torch .xpu .DoubleTensor
68- torch .cuda .ShortTensor = torch .xpu .ShortTensor
69- torch .cuda .LongTensor = torch .xpu .LongTensor
70- torch .cuda .IntStorage = torch .xpu .IntStorage
71- torch .cuda .LongStorage = torch .xpu .LongStorage
7258 torch .cuda .__annotations__ = torch .xpu .__annotations__
7359 torch .cuda .__package__ = torch .xpu .__package__
7460 torch .cuda .__builtins__ = torch .xpu .__builtins__
75- torch .cuda .CharTensor = torch .xpu .CharTensor
7661 torch .cuda .List = torch .xpu .List
7762 torch .cuda ._lazy_init = torch .xpu ._lazy_init
78- torch .cuda .BFloat16Tensor = torch .xpu .BFloat16Tensor
79- torch .cuda .DoubleStorage = torch .xpu .DoubleStorage
80- torch .cuda .ByteTensor = torch .xpu .ByteTensor
8163 torch .cuda .StreamContext = torch .xpu .StreamContext
82- torch .cuda .ComplexDoubleStorage = torch .xpu .ComplexDoubleStorage
83- torch .cuda .ShortStorage = torch .xpu .ShortStorage
8464 torch .cuda ._lazy_call = torch .xpu ._lazy_call
85- torch .cuda .HalfStorage = torch .xpu .HalfStorage
8665 torch .cuda .random = torch .xpu .random
8766 torch .cuda ._device = torch .xpu ._device
88- torch .cuda .classproperty = torch .xpu .classproperty
8967 torch .cuda .__name__ = torch .xpu .__name__
9068 torch .cuda ._device_t = torch .xpu ._device_t
91- torch .cuda .warnings = torch .xpu .warnings
9269 torch .cuda .__spec__ = torch .xpu .__spec__
93- torch .cuda .BoolTensor = torch .xpu .BoolTensor
94- torch .cuda .CharStorage = torch .xpu .CharStorage
9570 torch .cuda .__file__ = torch .xpu .__file__
96- torch .cuda ._is_in_bad_fork = torch .xpu .lazy_init ._is_in_bad_fork
9771 # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
9872
73+ if legacy :
74+ torch .cuda .os = torch .xpu .os
75+ torch .cuda .Device = torch .xpu .Device
76+ torch .cuda .warnings = torch .xpu .warnings
77+ torch .cuda .classproperty = torch .xpu .classproperty
78+ torch .UntypedStorage .cuda = torch .UntypedStorage .xpu
79+ if float (ipex .__version__ [:3 ]) < 2.3 :
80+ torch .cuda ._initialization_lock = torch .xpu .lazy_init ._initialization_lock
81+ torch .cuda ._initialized = torch .xpu .lazy_init ._initialized
82+ torch .cuda ._is_in_bad_fork = torch .xpu .lazy_init ._is_in_bad_fork
83+ torch .cuda ._lazy_seed_tracker = torch .xpu .lazy_init ._lazy_seed_tracker
84+ torch .cuda ._queued_calls = torch .xpu .lazy_init ._queued_calls
85+ torch .cuda ._tls = torch .xpu .lazy_init ._tls
86+ torch .cuda .threading = torch .xpu .lazy_init .threading
87+ torch .cuda .traceback = torch .xpu .lazy_init .traceback
88+ torch .cuda ._lazy_new = torch .xpu ._lazy_new
89+
90+ torch .cuda .FloatTensor = torch .xpu .FloatTensor
91+ torch .cuda .FloatStorage = torch .xpu .FloatStorage
92+ torch .cuda .BFloat16Tensor = torch .xpu .BFloat16Tensor
93+ torch .cuda .BFloat16Storage = torch .xpu .BFloat16Storage
94+ torch .cuda .HalfTensor = torch .xpu .HalfTensor
95+ torch .cuda .HalfStorage = torch .xpu .HalfStorage
96+ torch .cuda .ByteTensor = torch .xpu .ByteTensor
97+ torch .cuda .ByteStorage = torch .xpu .ByteStorage
98+ torch .cuda .DoubleTensor = torch .xpu .DoubleTensor
99+ torch .cuda .DoubleStorage = torch .xpu .DoubleStorage
100+ torch .cuda .ShortTensor = torch .xpu .ShortTensor
101+ torch .cuda .ShortStorage = torch .xpu .ShortStorage
102+ torch .cuda .LongTensor = torch .xpu .LongTensor
103+ torch .cuda .LongStorage = torch .xpu .LongStorage
104+ torch .cuda .IntTensor = torch .xpu .IntTensor
105+ torch .cuda .IntStorage = torch .xpu .IntStorage
106+ torch .cuda .CharTensor = torch .xpu .CharTensor
107+ torch .cuda .CharStorage = torch .xpu .CharStorage
108+ torch .cuda .BoolTensor = torch .xpu .BoolTensor
109+ torch .cuda .BoolStorage = torch .xpu .BoolStorage
110+ torch .cuda .ComplexFloatStorage = torch .xpu .ComplexFloatStorage
111+ torch .cuda .ComplexDoubleStorage = torch .xpu .ComplexDoubleStorage
112+
113+ if not legacy or float (ipex .__version__ [:3 ]) >= 2.3 :
114+ torch .cuda ._initialization_lock = torch .xpu ._initialization_lock
115+ torch .cuda ._initialized = torch .xpu ._initialized
116+ torch .cuda ._is_in_bad_fork = torch .xpu ._is_in_bad_fork
117+ torch .cuda ._lazy_seed_tracker = torch .xpu ._lazy_seed_tracker
118+ torch .cuda ._queued_calls = torch .xpu ._queued_calls
119+ torch .cuda ._tls = torch .xpu ._tls
120+ torch .cuda .threading = torch .xpu .threading
121+ torch .cuda .traceback = torch .xpu .traceback
122+
99123 # Memory:
100- torch .cuda .memory = torch .xpu .memory
101124 if 'linux' in sys .platform and "WSL2" in os .popen ("uname -a" ).read ():
102125 torch .xpu .empty_cache = lambda : None
103126 torch .cuda .empty_cache = torch .xpu .empty_cache
127+
128+ if legacy :
129+ torch .cuda .memory_summary = torch .xpu .memory_summary
130+ torch .cuda .memory_snapshot = torch .xpu .memory_snapshot
131+ torch .cuda .memory = torch .xpu .memory
104132 torch .cuda .memory_stats = torch .xpu .memory_stats
105- torch .cuda .memory_summary = torch .xpu .memory_summary
106- torch .cuda .memory_snapshot = torch .xpu .memory_snapshot
107133 torch .cuda .memory_allocated = torch .xpu .memory_allocated
108134 torch .cuda .max_memory_allocated = torch .xpu .max_memory_allocated
109135 torch .cuda .memory_reserved = torch .xpu .memory_reserved
@@ -128,52 +154,64 @@ def ipex_init(): # pylint: disable=too-many-statements
128154 torch .cuda .initial_seed = torch .xpu .initial_seed
129155
130156 # AMP:
131- torch .cuda .amp = torch .xpu .amp
132- torch .is_autocast_enabled = torch .xpu .is_autocast_xpu_enabled
133- torch .get_autocast_gpu_dtype = torch .xpu .get_autocast_xpu_dtype
157+ if legacy :
158+ torch .xpu .amp .custom_fwd = torch .cuda .amp .custom_fwd
159+ torch .xpu .amp .custom_bwd = torch .cuda .amp .custom_bwd
160+ torch .cuda .amp = torch .xpu .amp
161+ if float (ipex .__version__ [:3 ]) < 2.3 :
162+ torch .is_autocast_enabled = torch .xpu .is_autocast_xpu_enabled
163+ torch .get_autocast_gpu_dtype = torch .xpu .get_autocast_xpu_dtype
134164
135- if not hasattr (torch .cuda .amp , "common" ):
136- torch .cuda .amp .common = contextlib .nullcontext ()
137- torch .cuda .amp .common .amp_definitely_not_available = lambda : False
165+ if not hasattr (torch .cuda .amp , "common" ):
166+ torch .cuda .amp .common = contextlib .nullcontext ()
167+ torch .cuda .amp .common .amp_definitely_not_available = lambda : False
138168
139- try :
140- torch .cuda .amp .GradScaler = torch .xpu .amp .GradScaler
141- except Exception : # pylint: disable=broad-exception-caught
142169 try :
143- from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
144- gradscaler_init ()
145170 torch .cuda .amp .GradScaler = torch .xpu .amp .GradScaler
146171 except Exception : # pylint: disable=broad-exception-caught
147- torch .cuda .amp .GradScaler = ipex .cpu .autocast ._grad_scaler .GradScaler
172+ try :
173+ from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
174+ gradscaler_init ()
175+ torch .cuda .amp .GradScaler = torch .xpu .amp .GradScaler
176+ except Exception : # pylint: disable=broad-exception-caught
177+ torch .cuda .amp .GradScaler = ipex .cpu .autocast ._grad_scaler .GradScaler
148178
149179 # C
150- torch ._C ._cuda_getCurrentRawStream = ipex ._C ._getCurrentStream
151- ipex ._C ._DeviceProperties .multi_processor_count = ipex ._C ._DeviceProperties .gpu_subslice_count
152- ipex ._C ._DeviceProperties .major = 2024
153- ipex ._C ._DeviceProperties .minor = 0
180+ if legacy and float (ipex .__version__ [:3 ]) < 2.3 :
181+ torch ._C ._cuda_getCurrentRawStream = ipex ._C ._getCurrentRawStream
182+ ipex ._C ._DeviceProperties .multi_processor_count = ipex ._C ._DeviceProperties .gpu_subslice_count
183+ ipex ._C ._DeviceProperties .major = 12
184+ ipex ._C ._DeviceProperties .minor = 1
185+ else :
186+ torch ._C ._cuda_getCurrentRawStream = torch ._C ._xpu_getCurrentRawStream
187+ torch ._C ._XpuDeviceProperties .multi_processor_count = torch ._C ._XpuDeviceProperties .gpu_subslice_count
188+ torch ._C ._XpuDeviceProperties .major = 12
189+ torch ._C ._XpuDeviceProperties .minor = 1
154190
155191 # Fix functions with ipex:
156- torch .cuda .mem_get_info = lambda device = None : [(torch .xpu .get_device_properties (device ).total_memory - torch .xpu .memory_reserved (device )), torch .xpu .get_device_properties (device ).total_memory ]
192+ # torch.xpu.mem_get_info always returns the total memory as free memory
193+ torch .xpu .mem_get_info = lambda device = None : [(torch .xpu .get_device_properties (device ).total_memory - torch .xpu .memory_reserved (device )), torch .xpu .get_device_properties (device ).total_memory ]
194+ torch .cuda .mem_get_info = torch .xpu .mem_get_info
157195 torch ._utils ._get_available_device_type = lambda : "xpu"
158196 torch .has_cuda = True
159197 torch .cuda .has_half = True
160198 torch .cuda .is_bf16_supported = lambda * args , ** kwargs : True
161199 torch .cuda .is_fp16_supported = lambda * args , ** kwargs : True
162200 torch .backends .cuda .is_built = lambda * args , ** kwargs : True
163201 torch .version .cuda = "12.1"
164- torch .cuda .get_device_capability = lambda * args , ** kwargs : [12 ,1 ]
202+ torch .cuda .get_arch_list = lambda : ["ats-m150" , "pvc" ]
203+ torch .cuda .get_device_capability = lambda * args , ** kwargs : (12 ,1 )
165204 torch .cuda .get_device_properties .major = 12
166205 torch .cuda .get_device_properties .minor = 1
167206 torch .cuda .ipc_collect = lambda * args , ** kwargs : None
168207 torch .cuda .utilization = lambda * args , ** kwargs : 0
169208
170- ipex_hijacks ()
171- if not torch .xpu .has_fp64_dtype () or os .environ .get ('IPEX_FORCE_ATTENTION_SLICE' , None ) is not None :
172- try :
173- from .diffusers import ipex_diffusers
174- ipex_diffusers ()
175- except Exception : # pylint: disable=broad-exception-caught
176- pass
209+ device_supports_fp64 , can_allocate_plus_4gb = ipex_hijacks (legacy = legacy )
210+ try :
211+ from .diffusers import ipex_diffusers
212+ ipex_diffusers (device_supports_fp64 = device_supports_fp64 , can_allocate_plus_4gb = can_allocate_plus_4gb )
213+ except Exception : # pylint: disable=broad-exception-caught
214+ pass
177215 torch .cuda .is_xpu_hijacked = True
178216 except Exception as e :
179217 return False , e
0 commit comments