Skip to content

Commit 7989caa

Browse files
authored
Merge pull request #5449 from OpenMathLib/gemm_batch
[WIP] Add BLAS interface to ?GEMM_BATCH to complement the CBLAS one
2 parents fe5402d + 99c077a commit 7989caa

File tree

3 files changed

+81
-24
lines changed

3 files changed

+81
-24
lines changed

interface/CMakeLists.txt

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,8 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS})
124124
#sdsdot, dsdot
125125
if (BUILD_SINGLE OR BUILD_DOUBLE)
126126
GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE")
127-
if(CBLAS_FLAG EQUAL 1)
128-
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false)
129-
endif ()
130-
endif ()
127+
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false)
128+
endif ()
131129
if (BUILD_DOUBLE)
132130
GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE")
133131
endif ()
@@ -162,10 +160,8 @@ if (BUILD_BFLOAT16)
162160
GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
163161
GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16")
164162
GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16")
165-
if(CBLAS_FLAG EQUAL 1)
166163
GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16")
167164
endif ()
168-
endif ()
169165
if (BUILD_HFLOAT16)
170166
GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16")
171167
endif ()
@@ -197,9 +193,7 @@ foreach (float_type ${FLOAT_TYPES})
197193
GenerateNamedObjects("max.c" "USE_ABS" "scamax" ${CBLAS_FLAG} "" "" true "COMPLEX")
198194
GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX")
199195
GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX")
200-
if(CBLAS_FLAG EQUAL 1)
201-
GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX")
202-
endif ()
196+
GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX")
203197
endif ()
204198
if (${float_type} STREQUAL "ZCOMPLEX")
205199
GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX")
@@ -209,9 +203,7 @@ foreach (float_type ${FLOAT_TYPES})
209203
GenerateNamedObjects("max.c" "USE_ABS" "dzamax" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
210204
GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
211205
GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
212-
if(CBLAS_FLAG EQUAL 1)
213-
GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
214-
endif ()
206+
GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
215207
endif ()
216208
endforeach ()
217209

@@ -262,7 +254,7 @@ if ( BUILD_COMPLEX AND NOT BUILD_SINGLE)
262254
GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "SINGLE")
263255
GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE")
264256
GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE")
265-
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "SINGLE")
257+
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 0 "" "" false "SINGLE")
266258
GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE")
267259
GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE")
268260
GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE")
@@ -276,7 +268,7 @@ if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)
276268
GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "DOUBLE")
277269
GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE")
278270
GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE")
279-
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "DOUBLE")
271+
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 0 "" "" false "DOUBLE")
280272
GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE")
281273
GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE")
282274
GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE")

interface/Makefile

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,16 @@ SBLAS3OBJS = \
7272
sgemm.$(SUFFIX) ssymm.$(SUFFIX) strmm.$(SUFFIX) \
7373
strsm.$(SUFFIX) ssyrk.$(SUFFIX) ssyr2k.$(SUFFIX) \
7474
somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\
75-
sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX)
75+
sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX) \
76+
sgemm_batch.$(SUFFIX)
7677

7778
ifeq ($(BUILD_BFLOAT16),1)
7879
BBLAS3OBJS = bgemm.$(SUFFIX)
7980
BBLAS2OBJS = bgemv.$(SUFFIX)
8081
BBLAS1OBJS = bscal.$(SUFFIX)
8182
SBBLAS1OBJS = sbdot.$(SUFFIX)
8283
SBBLAS2OBJS = sbgemv.$(SUFFIX)
83-
SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX)
84+
SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) sbgemm_batch.$(SUFFIX)
8485
SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
8586
endif
8687

@@ -111,7 +112,8 @@ DBLAS3OBJS = \
111112
dgemm.$(SUFFIX) dsymm.$(SUFFIX) dtrmm.$(SUFFIX) \
112113
dtrsm.$(SUFFIX) dsyrk.$(SUFFIX) dsyr2k.$(SUFFIX) \
113114
domatcopy.$(SUFFIX) dimatcopy.$(SUFFIX)\
114-
dgeadd.$(SUFFIX) dgemmt.$(SUFFIX) dgemmtr.$(SUFFIX)
115+
dgeadd.$(SUFFIX) dgemmt.$(SUFFIX) dgemmtr.$(SUFFIX) \
116+
dgemm_batch.$(SUFFIX)
115117

116118
CBLAS1OBJS = \
117119
caxpy.$(SUFFIX) caxpyc.$(SUFFIX) cswap.$(SUFFIX) \
@@ -140,7 +142,8 @@ CBLAS3OBJS = \
140142
ctrsm.$(SUFFIX) csyrk.$(SUFFIX) csyr2k.$(SUFFIX) \
141143
chemm.$(SUFFIX) cherk.$(SUFFIX) cher2k.$(SUFFIX) \
142144
comatcopy.$(SUFFIX) cimatcopy.$(SUFFIX)\
143-
cgeadd.$(SUFFIX) cgemmt.$(SUFFIX) cgemmtr.$(SUFFIX)
145+
cgeadd.$(SUFFIX) cgemmt.$(SUFFIX) cgemmtr.$(SUFFIX) \
146+
cgemm_batch.$(SUFFIX)
144147

145148
ZBLAS1OBJS = \
146149
zaxpy.$(SUFFIX) zaxpyc.$(SUFFIX) zswap.$(SUFFIX) \
@@ -169,7 +172,8 @@ ZBLAS3OBJS = \
169172
ztrsm.$(SUFFIX) zsyrk.$(SUFFIX) zsyr2k.$(SUFFIX) \
170173
zhemm.$(SUFFIX) zherk.$(SUFFIX) zher2k.$(SUFFIX) \
171174
zomatcopy.$(SUFFIX) zimatcopy.$(SUFFIX)\
172-
zgeadd.$(SUFFIX) zgemmt.$(SUFFIX) zgemmtr.$(SUFFIX)
175+
zgeadd.$(SUFFIX) zgemmt.$(SUFFIX) zgemmtr.$(SUFFIX) \
176+
zgemm_batch.$(SUFFIX)
173177

174178
ifeq ($(SUPPORT_GEMM3M), 1)
175179

@@ -2539,3 +2543,19 @@ cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param
25392543

25402544
cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
25412545
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)
2546+
2547+
sbgemm_batch.$(SUFFIX) sbgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2548+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2549+
2550+
sgemm_batch.$(SUFFIX) sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2551+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2552+
2553+
dgemm_batch.$(SUFFIX) dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2554+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2555+
2556+
cgemm_batch.$(SUFFIX) cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2557+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2558+
2559+
zgemm_batch.$(SUFFIX) zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2560+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2561+

interface/gemm_batch.c

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ static size_t zgemm_small_kernel_b0[] = {
114114
#endif
115115
#endif
116116

117+
#ifndef CBLAS
118+
void NAME(char *transa_array, char *transb_array,
119+
blasint * m_array, blasint * n_array, blasint * k_array,
120+
FLOAT * alpha_array,
121+
IFLOAT ** a_array, blasint * lda_array,
122+
IFLOAT ** b_array, blasint * ldb_array,
123+
FLOAT * beta_array,
124+
FLOAT ** c_array, blasint * ldc_array, blasint * gcount, blasint * group_size) {
125+
blasint group_count = *gcount;
126+
#else
127+
117128
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array,
118129
blasint * m_array, blasint * n_array, blasint * k_array,
119130
#ifndef COMPLEX
@@ -134,8 +145,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
134145
FLOAT ** a_array=(FLOAT**)va_array;
135146
FLOAT ** b_array=(FLOAT**)vb_array;
136147
FLOAT ** c_array=(FLOAT**)vc_array;
137-
138148
#endif
149+
#endif
150+
BLASLONG group_m, group_n, group_k;
151+
BLASLONG group_lda, group_ldb, group_ldc;
152+
139153
blas_arg_t * args_array=NULL;
140154

141155
int mode=0, group_mode=0;
@@ -148,8 +162,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
148162
blasint info;
149163

150164
void * group_alpha, * group_beta;
151-
BLASLONG group_m, group_n, group_k;
152-
BLASLONG group_lda, group_ldb, group_ldc;
153165
void * group_routine=NULL;
154166
#ifdef SMALL_MATRIX_OPT
155167
void * group_small_matrix_opt_routine=NULL;
@@ -201,7 +213,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
201213
group_transa = -1;
202214
group_transb = -1;
203215
info = 0;
204-
216+
217+
#if defined(CBLAS)
205218
if (order == CblasColMajor) {
206219
group_m = m_array[i];
207220
group_n = n_array[i];
@@ -254,7 +267,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
254267
group_lda = ldb_array[i];
255268
group_ldb = lda_array[i];
256269
group_ldc = ldc_array[i];
257-
270+
258271
if (transb_array[i] == CblasNoTrans) group_transa = 0;
259272
if (transb_array[i] == CblasTrans) group_transa = 1;
260273
#ifndef COMPLEX
@@ -273,6 +286,32 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
273286
if (transa_array[i] == CblasConjNoTrans) group_transb = 2;
274287
if (transa_array[i] == CblasConjTrans) group_transb = 3;
275288
#endif
289+
290+
#else
291+
group_m = m_array[i];
292+
group_n = n_array[i];
293+
group_k = k_array[i];
294+
295+
group_lda = lda_array[i];
296+
group_ldb = ldb_array[i];
297+
group_ldc = ldc_array[i];
298+
299+
if (transb_array[i] == 'N') group_transa = 0;
300+
if (transb_array[i] == 'T') group_transa = 1;
301+
#ifndef COMPLEX
302+
if (transb_array[i] == 'C') group_transa = 1;
303+
#else
304+
if (transb_array[i] == 'C') group_transa = 3;
305+
#endif
306+
if (transa_array[i] == 'N') group_transb = 0;
307+
if (transa_array[i] == 'T') group_transb = 1;
308+
#ifndef COMPLEX
309+
if (transa_array[i] == 'C') group_transb = 1;
310+
#else
311+
if (transa_array[i] == 'C') group_transb = 3;
312+
#endif
313+
#endif
314+
276315
group_nrowa = group_m;
277316
if (group_transa & 1) group_nrowa = group_k;
278317
group_nrowb = group_k;
@@ -288,7 +327,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
288327
if (group_m < 0) info = 3;
289328
if (group_transb < 0) info = 2;
290329
if (group_transa < 0) info = 1;
330+
#if defined(CBLAS)
291331
}
332+
#endif
292333

293334
if (info >= 0) {
294335
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
@@ -344,13 +385,17 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
344385
args_array[count].alpha=group_alpha;
345386
args_array[count].beta=group_beta;
346387

388+
#if defined(CBLAS)
347389
if (order == CblasColMajor) {
348390
args_array[count].a=(a_array[matrix_idx+j]);
349391
args_array[count].b=(b_array[matrix_idx+j]);
350392
}else if(order == CblasRowMajor){
393+
#endif
351394
args_array[count].a=(b_array[matrix_idx+j]);
352395
args_array[count].b=(a_array[matrix_idx+j]);
396+
#if defined(CBLAS)
353397
}
398+
#endif
354399

355400
args_array[count].c=(c_array[matrix_idx+j]);
356401

0 commit comments

Comments
 (0)