1
1
/*
2
- // Copyright (c) 2023 Ben Ashbaugh
2
+ // Copyright (c) 2023-2025 Ben Ashbaugh
3
3
//
4
4
// SPDX-License-Identifier: MIT
5
5
*/
@@ -183,6 +183,20 @@ static cl_svm_capabilities_khr getSystemSVMCaps(cl_device_id device)
183
183
return ret;
184
184
}
185
185
186
+ struct SUSMFuncs
187
+ {
188
+ clHostMemAllocINTEL_fn clHostMemAllocINTEL;
189
+ clDeviceMemAllocINTEL_fn clDeviceMemAllocINTEL;
190
+ clSharedMemAllocINTEL_fn clSharedMemAllocINTEL;
191
+ clMemFreeINTEL_fn clMemFreeINTEL;
192
+ clMemBlockingFreeINTEL_fn clMemBlockingFreeINTEL;
193
+ clGetMemAllocInfoINTEL_fn clGetMemAllocInfoINTEL;
194
+ clSetKernelArgMemPointerINTEL_fn clSetKernelArgMemPointerINTEL;
195
+ clEnqueueMemFillINTEL_fn clEnqueueMemFillINTEL;
196
+ clEnqueueMemcpyINTEL_fn clEnqueueMemcpyINTEL;
197
+ clEnqueueMemAdviseINTEL_fn clEnqueueMemAdviseINTEL;
198
+ };
199
+
186
200
struct SAllocInfo
187
201
{
188
202
cl_uint TypeIndex = ~0 ;
@@ -219,6 +233,7 @@ struct SLayerContext
219
233
220
234
for (auto platform: platforms) {
221
235
getSVMTypesForPlatform (platform);
236
+ getUSMFuncsForPlatform (platform);
222
237
}
223
238
}
224
239
@@ -232,6 +247,11 @@ struct SLayerContext
232
247
return TypeCapsDevice[device];
233
248
}
234
249
250
+ const SUSMFuncs& getUSMFuncs (cl_platform_id platform)
251
+ {
252
+ return USMFuncs[platform];
253
+ }
254
+
235
255
bool isKnownAlloc (cl_context context, const void * ptr) const
236
256
{
237
257
if (AllocMaps.find (context) != AllocMaps.end ()) {
@@ -435,9 +455,47 @@ struct SLayerContext
435
455
}
436
456
}
437
457
458
+ void getUSMFuncsForPlatform (cl_platform_id platform)
459
+ {
460
+ SUSMFuncs& funcs = USMFuncs[platform];
461
+
462
+ funcs.clHostMemAllocINTEL = (clHostMemAllocINTEL_fn)
463
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
464
+ platform, " clHostMemAllocINTEL" );
465
+ funcs.clDeviceMemAllocINTEL = (clDeviceMemAllocINTEL_fn)
466
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
467
+ platform, " clDeviceMemAllocINTEL" );
468
+ funcs.clSharedMemAllocINTEL = (clSharedMemAllocINTEL_fn)
469
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
470
+ platform, " clSharedMemAllocINTEL" );
471
+ funcs.clMemFreeINTEL = (clMemFreeINTEL_fn)
472
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
473
+ platform, " clMemFreeINTEL" );
474
+ funcs.clMemBlockingFreeINTEL = (clMemBlockingFreeINTEL_fn)
475
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
476
+ platform, " clMemBlockingFreeINTEL" );
477
+ funcs.clGetMemAllocInfoINTEL = (clGetMemAllocInfoINTEL_fn)
478
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
479
+ platform, " clGetMemAllocInfoINTEL" );
480
+ funcs.clSetKernelArgMemPointerINTEL = (clSetKernelArgMemPointerINTEL_fn)
481
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
482
+ platform, " clSetKernelArgMemPointerINTEL" );
483
+ funcs.clEnqueueMemFillINTEL = (clEnqueueMemFillINTEL_fn)
484
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
485
+ platform, " clEnqueueMemFillINTEL" );
486
+ funcs.clEnqueueMemcpyINTEL = (clEnqueueMemcpyINTEL_fn)
487
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
488
+ platform, " clEnqueueMemcpyINTEL" );
489
+ funcs.clEnqueueMemAdviseINTEL = (clEnqueueMemAdviseINTEL_fn)
490
+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
491
+ platform, " clEnqueueMemAdviseINTEL" );
492
+ }
493
+
438
494
std::map<cl_platform_id, std::vector<cl_svm_capabilities_khr>> TypeCapsPlatform;
439
495
std::map<cl_device_id, std::vector<cl_svm_capabilities_khr>> TypeCapsDevice;
440
496
497
+ std::map<cl_platform_id, SUSMFuncs> USMFuncs;
498
+
441
499
typedef std::map<const void *, SAllocInfo> CAllocMap;
442
500
std::map<cl_context, CAllocMap> AllocMaps;
443
501
@@ -597,6 +655,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
597
655
cl_int* errcode_ret)
598
656
{
599
657
cl_platform_id platform = getPlatform (context);
658
+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
600
659
601
660
const auto & typeCapsPlatform = getLayerContext ().getSVMCaps (platform);
602
661
if (svm_type_index >= typeCapsPlatform.size ()) {
@@ -626,7 +685,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
626
685
const auto caps = typeCapsPlatform[svm_type_index];
627
686
if ((caps & CL_SVM_TYPE_MACRO_DEVICE_KHR) == CL_SVM_TYPE_MACRO_DEVICE_KHR) {
628
687
isUSMPointer = true ;
629
- ret = clDeviceMemAllocINTEL (
688
+ ret = USMFuncs. clDeviceMemAllocINTEL (
630
689
context,
631
690
device,
632
691
nullptr ,
@@ -636,7 +695,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
636
695
}
637
696
else if ((caps & CL_SVM_TYPE_MACRO_HOST_KHR) == CL_SVM_TYPE_MACRO_HOST_KHR) {
638
697
isUSMPointer = true ;
639
- ret = clHostMemAllocINTEL (
698
+ ret = USMFuncs. clHostMemAllocINTEL (
640
699
context,
641
700
nullptr ,
642
701
size,
@@ -653,7 +712,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
653
712
}
654
713
else if ((caps & CL_SVM_TYPE_MACRO_SINGLE_DEVICE_SHARED_KHR) == CL_SVM_TYPE_MACRO_SINGLE_DEVICE_SHARED_KHR) {
655
714
isUSMPointer = true ;
656
- ret = clSharedMemAllocINTEL (
715
+ ret = USMFuncs. clSharedMemAllocINTEL (
657
716
context,
658
717
device,
659
718
nullptr ,
@@ -739,7 +798,9 @@ cl_int CL_API_CALL clSVMFreeWithPropertiesKHR_EMU(
739
798
740
799
cl_int errorCode = CL_SUCCESS;
741
800
if (isUSMPtr (context, ptr)) {
742
- errorCode = clMemBlockingFreeINTEL (
801
+ cl_platform_id platform = getPlatform (context);
802
+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
803
+ errorCode = USMFuncs.clMemBlockingFreeINTEL (
743
804
context,
744
805
ptr);
745
806
} else if (isSVMPtr (context, ptr)) {
@@ -1191,7 +1252,9 @@ cl_int CL_API_CALL clSetKernelArgSVMPointer_override(
1191
1252
cl_context context = getContext (kernel);
1192
1253
1193
1254
if (isUSMPtr (context, arg_value)) {
1194
- return clSetKernelArgMemPointerINTEL (
1255
+ cl_platform_id platform = getPlatform (context);
1256
+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1257
+ return USMFuncs.clSetKernelArgMemPointerINTEL (
1195
1258
kernel,
1196
1259
arg_index,
1197
1260
arg_value);
@@ -1284,7 +1347,9 @@ void CL_API_CALL clSVMFree_override(
1284
1347
void * ptr)
1285
1348
{
1286
1349
if (isUSMPtr (context, ptr)) {
1287
- clMemFreeINTEL (context, ptr);
1350
+ cl_platform_id platform = getPlatform (context);
1351
+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1352
+ USMFuncs.clMemFreeINTEL (context, ptr);
1288
1353
} else {
1289
1354
g_pNextDispatch->clSVMFree (context, ptr);
1290
1355
}
@@ -1351,7 +1416,9 @@ cl_int CL_API_CALL clEnqueueSVMMemcpy_override(
1351
1416
}
1352
1417
1353
1418
if (isUSMPtr (context, dst_ptr) || isUSMPtr (context, src_ptr)) {
1354
- cl_int ret = clEnqueueMemcpyINTEL (
1419
+ cl_platform_id platform = getPlatform (context);
1420
+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1421
+ cl_int ret = USMFuncs.clEnqueueMemcpyINTEL (
1355
1422
command_queue,
1356
1423
blocking_copy,
1357
1424
dst_ptr,
@@ -1407,7 +1474,9 @@ cl_int CL_API_CALL clEnqueueSVMMemFill_override(
1407
1474
}
1408
1475
1409
1476
if (isUSMPtr (context, svm_ptr)) {
1410
- cl_int ret = clEnqueueMemFillINTEL (
1477
+ cl_platform_id platform = getPlatform (context);
1478
+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1479
+ cl_int ret = USMFuncs.clEnqueueMemFillINTEL (
1411
1480
command_queue,
1412
1481
svm_ptr,
1413
1482
pattern,
0 commit comments