1919# since the model classes inherit torch.nn.Module.
2020import math
2121
22+ import numba
2223import numpy as np
2324import torch
2425from torch .autograd import Function
2526from torch .nn import functional as F
26- import numba
2727
2828from neural_compressor .torch .utils import accelerator , logger
2929
@@ -301,11 +301,11 @@ def unpack_tensor_with_torch(self, packed_tensor):
301301 unpacked_tensor [:, index ].copy_ (tmp .type (target_dtype ))
302302 accelerator .synchronize ()
303303 return unpacked_tensor
304-
304+
305305 @staticmethod
306306 @numba .jit (nopython = True , parallel = True )
307307 def pack_array_with_numba_b4_c32 (
308- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
308+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
309309 ) -> np .ndarray :
310310 for i in range (new_in_features ):
311311 packed_array [:, i ] = (
@@ -319,11 +319,11 @@ def pack_array_with_numba_b4_c32(
319319 | (raw_array [:, i * n_pack ] & 0b1111 )
320320 )
321321 return packed_array
322-
322+
323323 @staticmethod
324324 @numba .jit (nopython = True , parallel = True )
325325 def pack_array_with_numba_b4_c16 (
326- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
326+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
327327 ) -> np .ndarray :
328328 for i in range (new_in_features ):
329329 packed_array [:, i ] = (
@@ -333,23 +333,20 @@ def pack_array_with_numba_b4_c16(
333333 | (raw_array [:, i * n_pack ] & 0b1111 )
334334 )
335335 return packed_array
336-
336+
337337 @staticmethod
338338 @numba .jit (nopython = True , parallel = True )
339339 def pack_array_with_numba_b4_c8 (
340- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
340+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
341341 ) -> np .ndarray :
342342 for i in range (new_in_features ):
343- packed_array [:, i ] = (
344- ((raw_array [:, i * n_pack + 1 ] & 0b1111 ) << 4 )
345- | (raw_array [:, i * n_pack ] & 0b1111 )
346- )
343+ packed_array [:, i ] = ((raw_array [:, i * n_pack + 1 ] & 0b1111 ) << 4 ) | (raw_array [:, i * n_pack ] & 0b1111 )
347344 return packed_array
348-
345+
349346 @staticmethod
350347 @numba .jit (nopython = True , parallel = True )
351348 def pack_array_with_numba_b4_c64 (
352- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
349+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
353350 ) -> np .ndarray :
354351 for i in range (new_in_features ):
355352 packed_array [:, i ] = (
@@ -372,11 +369,10 @@ def pack_array_with_numba_b4_c64(
372369 )
373370 return packed_array
374371
375-
376372 @staticmethod
377373 @numba .jit (nopython = True , parallel = True )
378374 def pack_array_with_numba_b8_c32 (
379- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
375+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
380376 ) -> np .ndarray :
381377 for i in range (new_in_features ):
382378 packed_array [:, i ] = (
@@ -386,11 +382,11 @@ def pack_array_with_numba_b8_c32(
386382 | (raw_array [:, i * n_pack ] & 0b11111111 )
387383 )
388384 return packed_array
389-
385+
390386 @staticmethod
391387 @numba .jit (nopython = True , parallel = True )
392388 def pack_array_with_numba_b8_c16 (
393- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
389+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
394390 ) -> np .ndarray :
395391 for i in range (new_in_features ):
396392 packed_array [:, i ] = (
@@ -400,20 +396,20 @@ def pack_array_with_numba_b8_c16(
400396 | (raw_array [:, i * n_pack ] & 0b11111111 )
401397 )
402398 return packed_array
403-
399+
404400 @staticmethod
405401 @numba .jit (nopython = True , parallel = True )
406402 def pack_array_with_numba_b8_c8 (
407- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
403+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
408404 ) -> np .ndarray :
409405 for i in range (new_in_features ):
410- packed_array [:, i ] = ( raw_array [:, i * n_pack ] & 0b11111111 )
406+ packed_array [:, i ] = raw_array [:, i * n_pack ] & 0b11111111
411407 return packed_array
412-
408+
413409 @staticmethod
414410 @numba .jit (nopython = True , parallel = True )
415411 def pack_array_with_numba_b8_c64 (
416- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
412+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
417413 ) -> np .ndarray :
418414 for i in range (new_in_features ):
419415 packed_array [:, i ] = (
@@ -427,11 +423,11 @@ def pack_array_with_numba_b8_c64(
427423 | (raw_array [:, i * n_pack ] & 0b11111111 )
428424 )
429425 return packed_array
430-
426+
431427 @staticmethod
432428 @numba .jit (nopython = True , parallel = True )
433429 def pack_array_with_numba_b2_c32 (
434- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
430+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
435431 ) -> np .ndarray :
436432 for i in range (new_in_features ):
437433 packed_array [:, i ] = (
@@ -457,7 +453,7 @@ def pack_array_with_numba_b2_c32(
457453 @staticmethod
458454 @numba .jit (nopython = True , parallel = True )
459455 def pack_array_with_numba_b2_c16 (
460- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
456+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
461457 ) -> np .ndarray :
462458 for i in range (new_in_features ):
463459 packed_array [:, i ] = (
@@ -471,11 +467,11 @@ def pack_array_with_numba_b2_c16(
471467 | (raw_array [:, i * n_pack ] & 0b11 )
472468 )
473469 return packed_array
474-
470+
475471 @staticmethod
476472 @numba .jit (nopython = True , parallel = True )
477473 def pack_array_with_numba_b2_c8 (
478- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
474+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
479475 ) -> np .ndarray :
480476 for i in range (new_in_features ):
481477 packed_array [:, i ] = (
@@ -485,11 +481,11 @@ def pack_array_with_numba_b2_c8(
485481 | (raw_array [:, i * n_pack ] & 0b11 )
486482 )
487483 return packed_array
488-
484+
489485 @staticmethod
490486 @numba .jit (nopython = True , parallel = True )
491487 def pack_array_with_numba_b2_c64 (
492- raw_array : np .ndarray , packed_array :np .ndarray , n_pack : int , new_in_features :int
488+ raw_array : np .ndarray , packed_array : np .ndarray , n_pack : int , new_in_features : int
493489 ) -> np .ndarray :
494490 for i in range (new_in_features ):
495491 packed_array [:, i ] = (
@@ -527,7 +523,7 @@ def pack_array_with_numba_b2_c64(
527523 | (raw_array [:, i * n_pack ] & 0b11 )
528524 )
529525 return packed_array
530-
526+
531527 def pack_array_with_numba (
532528 self , raw_array : np .ndarray , n_pack : int , bits : int , compress_bits : int , compression_dtype = np .int32
533529 ) -> np .ndarray :
@@ -547,17 +543,18 @@ def pack_array_with_numba(
547543 new_in_features = (in_features + n_pack - 1 ) // n_pack
548544 packed_array = np .zeros ((out_features , new_in_features ), dtype = compression_dtype )
549545 raw_array = raw_array .astype (compression_dtype )
550-
546+
551547 pack_method_name = f"pack_array_with_numba_b{ bits } _c{ compress_bits } "
552548 pack_method = getattr (self , pack_method_name )
553549 return pack_method (raw_array , packed_array , n_pack , new_in_features )
554-
550+
555551 @staticmethod
556552 @numba .jit (nopython = True )
557553 def pack_array_with_numba_yi (
558554 raw_tensor : np .ndarray , n_pack : int , bits : int , compression_dtype = np .int32
559555 ) -> np .ndarray :
560556 """Packs the input tensor by combining elements into a specified bit-width format using NumPy.
557+
561558 Args:
562559 raw_tensor (np.ndarray): The tensor to be packed. Shape: [out_features, in_features] or [1, in_features].
563560 n_pack (int): The number of elements to be packed together.
@@ -575,7 +572,7 @@ def pack_array_with_numba_yi(
575572 for i in range (new_in_features ):
576573 packed_tensor [:, i ] = (
577574 (raw_tensor [:, i * n_pack + 7 ] << 28 )
578- | (raw_tensor [:, i * n_pack + 6 ] << 24 )
575+ | (raw_tensor [:, i * n_pack + 6 ] << 24 )
579576 | (raw_tensor [:, i * n_pack + 5 ] << 20 )
580577 | (raw_tensor [:, i * n_pack + 4 ] << 16 )
581578 | (raw_tensor [:, i * n_pack + 3 ] << 12 )
@@ -585,25 +582,29 @@ def pack_array_with_numba_yi(
585582 )
586583
587584 return packed_tensor
588-
585+
589586 def pack_tensor_with_reshape (self , raw_tensor ):
590587 raw_array = raw_tensor .cpu ().numpy ()
591588 target_len = np .ceil (raw_array .shape [1 ] / self .n_pack ).astype (int )
592589 target_dtype = torch .tensor (0 , dtype = self .compression_dtype ).numpy ().dtype
593590 reshaped = raw_array .reshape (- 1 , self .n_pack )
594591 packed_array = np .zeros (reshaped .shape [0 ], dtype = target_dtype )
595592 for i in range (self .n_pack ):
596- packed_array |= (reshaped [:, i ].astype (target_dtype ) << (self .bits * i ))
597-
598- packed_tensor = torch .from_numpy (packed_array .reshape ((raw_array .shape [0 ], target_len ))).to (device = raw_tensor .device )
593+ packed_array |= reshaped [:, i ].astype (target_dtype ) << (self .bits * i )
594+
595+ packed_tensor = torch .from_numpy (packed_array .reshape ((raw_array .shape [0 ], target_len ))).to (
596+ device = raw_tensor .device
597+ )
599598 return packed_tensor
600599
601600 def pack_tensor_with_numpy (self , raw_tensor ):
602601 if self .bits not in [2 , 4 , 8 ]:
603602 return self .pack_tensor_with_reshape (raw_tensor )
604603 compression_dtype = torch .tensor (0 , dtype = self .compression_dtype ).numpy ().dtype
605604 # packed_array = self.pack_array_with_numba_yi(raw_tensor.cpu().numpy(), self.n_pack, self.bits, compression_dtype)
606- packed_array = self .pack_array_with_numba (raw_tensor .cpu ().numpy (), self .n_pack , self .bits , self .compress_bits , compression_dtype )
605+ packed_array = self .pack_array_with_numba (
606+ raw_tensor .cpu ().numpy (), self .n_pack , self .bits , self .compress_bits , compression_dtype
607+ )
607608 return torch .from_numpy (packed_array ).to (device = raw_tensor .device )
608609
609610 def unpack_tensor_with_numpy (self , packed_tensor ):
0 commit comments