@@ -874,3 +874,95 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
874874                    break 
875875
876876    return  np.asarray(W), gap, tol, n_iter +  1 
877+ 
878+ @ cython.boundscheck (False )
879+ @ cython.wraparound (False )
880+ @ cython.cdivision (True )
881+ def  enet_coordinate_descent_complex (floating[::1, :] W , floating l1_reg ,
882+                                     floating l2_reg ,
883+                                     floating[::1, :] Xr ,
884+                                     floating[::1, :] Xi ,
885+                                     floating[::1, :] Y ,
886+                                     int max_iter , floating tol , object rng ,
887+                                     bint random = 0 ):
888+     """ Cython version of the coordinate descent algorithm
889+         for Elastic-Net mult-task regression in complex domain 
890+ 
891+     """  
892+ 
893+     #  fused types version of BLAS functions
894+     if  floating is  float :
895+         dtype =  np.float32
896+         gemv =  sgemv
897+         dot =  sdot
898+         copy =  scopy
899+     else :
900+         dtype =  np.float64
901+         gemv =  dgemv
902+         dot =  ddot
903+         copy =  dcopy
904+ 
905+     #  get the data information into easy vars
906+     cdef unsigned  int  n_samples =  Xr.shape[0 ]
907+     cdef unsigned  int  n_features =  Xr.shape[1 ]
908+ 
909+     #  initial value of the residuals
910+     cdef floating[::1 ] Rr =  np.empty(n_samples, dtype = dtype)
911+     cdef floating[::1 ] Ri =  np.empty(n_samples, dtype = dtype)
912+ 
913+     cdef floating[:] w_ii =  np.zeros(2 , dtype = dtype)
914+     cdef unsigned  int  ii
915+     cdef unsigned  int  jj
916+     cdef unsigned  int  n_iter =  0 
917+     cdef unsigned  int  f_iter
918+     cdef UINT32_t rand_r_state_seed =  rng.randint(0 , RAND_R_MAX)
919+     cdef UINT32_t*  rand_r_state =  & rand_r_state_seed
920+ 
921+     cdef floating*  W_ptr =  & W[0 , 0 ]
922+     cdef floating*  Y_ptr =  & Y[0 , 0 ]
923+ 
924+     if  l1_reg ==  0 :
925+         warnings.warn(" Coordinate descent with l1_reg=0 may lead to unexpected" 
926+             "  results and is discouraged."  )
927+ 
928+     with  nogil:
929+         #  Compute Rr and Ri: real and imaginary parts of the residual
930+         #  real part: Yr - np.dot(Xr, Wr) + np.dot(Xi, Wi)
931+         copy(n_samples, Y_ptr, 1 , & Rr[0 ], 1 )
932+         gemv(CblasColMajor, CblasNoTrans,
933+              n_samples, n_features, - 1.0 , & Xr[0 , 0 ], n_samples,
934+              W_ptr, 2 , 1.0 , & Rr[0 ], 1 )
935+         gemv(CblasColMajor, CblasNoTrans,
936+              n_samples, n_features, 1.0 , & Xi[0 , 0 ], n_samples,
937+              W_ptr +  1 , 2 , 1.0 , & Rr[0 ], 1 )
938+ 
939+         #  imaginary part:
940+         #  real part: Yr - np.dot(Xr, Wi) - np.dot(Xi, Wr)
941+         copy(n_samples, Y_ptr +  n_samples, 1 , & Ri[0 ], 1 )
942+         gemv(CblasColMajor, CblasNoTrans,
943+              n_samples, n_features, - 1.0 , & Xr[0 , 0 ], n_samples,
944+              W_ptr +  1 , 2 , 1.0 , & Ri[0 ], 1 )
945+         gemv(CblasColMajor, CblasNoTrans,
946+              n_samples, n_features, - 1.0 , & Xi[0 , 0 ], n_samples,
947+              W_ptr, 2 , 1.0 , & Ri[0 ], 1 )
948+ 
949+         #  tol = tol * linalg.norm(Y, ord='fro') ** 2
950+         tol =  tol *  dot(n_samples *  2 , Y_ptr, 1 , Y_ptr, 1 )
951+ 
952+         for  n_iter in  range (max_iter):
953+             for  f_iter in  range (n_features):  #  Loop over coordinates
954+                 #  select a coordinate
955+                 if  random:
956+                     ii =  rand_int(n_features, rand_r_state)
957+                 else :
958+                     ii =  f_iter
959+ 
960+                 #  w_ii = W[:, ii] # Store previous value
961+                 w_ii[0 ] =  W[0 , ii]
962+                 w_ii[1 ] =  W[1 , ii]
963+ 
964+                 #  if np.sum(w_ii ** 2) != 0.0:  # can do better
965+                 if  w_ii[0 ] !=  0.0  or  w_ii[1 ] !=  0.0 :
966+                     #  Remove contributions of w_ii from R
967+ 
968+                 #  prepare for the soft-thresholding
0 commit comments