@@ -126,30 +126,30 @@ namespace spartan::math
 
         Matrix normalized;
 #if defined(__AVX2__)
-        // Create scale reciprocal vector for faster division
-        __m128 scaleRecip = _mm_set_ps(1.0f, 1.0f / scale.z, 1.0f / scale.y, 1.0f / scale.x);
+        // create scale reciprocal vector for faster division (per-row scaling)
+        __m128 scale_recip = _mm_set_ps(1.0f, 1.0f / scale.z, 1.0f / scale.y, 1.0f / scale.x);
 
-        // Load and normalize first three columns using loadu for unaligned safety
+        // load and normalize first three columns
         __m128 col0 = _mm_loadu_ps(Data() + 0);
         __m128 col1 = _mm_loadu_ps(Data() + 4);
         __m128 col2 = _mm_loadu_ps(Data() + 8);
 
-        // Normalize by scale
-        col0 = _mm_mul_ps(col0, scaleRecip);
-        col1 = _mm_mul_ps(col1, scaleRecip);
-        col2 = _mm_mul_ps(col2, scaleRecip);
+        // normalize by scale (each row element divided by its row's scale)
+        col0 = _mm_mul_ps(col0, scale_recip);
+        col1 = _mm_mul_ps(col1, scale_recip);
+        col2 = _mm_mul_ps(col2, scale_recip);
 
-        // Store in normalized matrix
+        // store in normalized matrix
         _mm_storeu_ps(&normalized.m00, col0);
         _mm_storeu_ps(&normalized.m01, col1);
         _mm_storeu_ps(&normalized.m02, col2);
 
-        // Set translation components to zero (last row except m33)
+        // set translation components to zero (last row except m33)
         normalized.m30 = 0.0f;
         normalized.m31 = 0.0f;
         normalized.m32 = 0.0f;
 
-        // Set last column
+        // set last column
         normalized.m03 = 0.0f;
         normalized.m13 = 0.0f;
         normalized.m23 = 0.0f;
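
The reciprocal vector turns per-element divides into three divides total: each stored column holds (m0j, m1j, m2j, m3j), so lane i of scale_recip scales row i, while lane 3 (1.0f) leaves the translation row untouched. A minimal scalar sketch of what the multiply computes, assuming the contiguous column-major layout implied by the loads (column j starting at Data() + 4 * j):

    // scalar equivalent of the scale_recip multiply above
    void normalize_columns(const float* m, float sx, float sy, float sz, float* out)
    {
        const float recip[4] = { 1.0f / sx, 1.0f / sy, 1.0f / sz, 1.0f };
        for (int col = 0; col < 3; ++col)     // first three columns only
            for (int row = 0; row < 4; ++row) // lane `row` of each column
                out[col * 4 + row] = m[col * 4 + row] * recip[row];
    }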
@@ -225,48 +225,31 @@ namespace spartan::math
 #if defined(__AVX2__)
         const float* data = Data();
 
-        // Calculate signs (scalar)
+        // calculate signs (scalar)
         float xs = (sign(m00 * m01 * m02 * m03) < 0) ? -1.0f : 1.0f;
         float ys = (sign(m10 * m11 * m12 * m13) < 0) ? -1.0f : 1.0f;
         float zs = (sign(m20 * m21 * m22 * m23) < 0) ? -1.0f : 1.0f;
 
-        // Define gather indices for rows (in float offsets)
-        __m128i idx0 = _mm_set_epi32(12, 8, 4, 0);  // m03, m02, m01, m00 (m03 typically 0)
-        __m128i idx1 = _mm_set_epi32(13, 9, 5, 1);  // m13, m12, m11, m10 (m13 typically 0)
-        __m128i idx2 = _mm_set_epi32(14, 10, 6, 2); // m23, m22, m21, m20 (m23 typically 0)
+        // define gather indices for rows (in float offsets) - only first 3 elements matter
+        __m128i idx0 = _mm_set_epi32(0, 8, 4, 0);  // m00, m01, m02 (last ignored)
+        __m128i idx1 = _mm_set_epi32(1, 9, 5, 1);  // m10, m11, m12 (last ignored)
+        __m128i idx2 = _mm_set_epi32(2, 10, 6, 2); // m20, m21, m22 (last ignored)
 
-        // Gather rows using AVX2 gather
+        // gather rows using avx2 gather
         __m128 row0 = _mm_i32gather_ps(data, idx0, 4);
         __m128 row1 = _mm_i32gather_ps(data, idx1, 4);
         __m128 row2 = _mm_i32gather_ps(data, idx2, 4);
 
-        // Square each component
-        __m128 square0 = _mm_mul_ps(row0, row0);
-        __m128 square1 = _mm_mul_ps(row1, row1);
-        __m128 square2 = _mm_mul_ps(row2, row2);
+        // use dot product to sum squares of first 3 elements (mask 0x71 = sum xyz, result in lowest)
+        __m128 len_sq0 = _mm_dp_ps(row0, row0, 0x71);
+        __m128 len_sq1 = _mm_dp_ps(row1, row1, 0x71);
+        __m128 len_sq2 = _mm_dp_ps(row2, row2, 0x71);
 
-        // Sum using permute and add (avoid hadd for better performance)
-        // For square0: sum all four elements (though fourth is 0^2)
-        __m128 shuf0 = _mm_permute_ps(square0, _MM_SHUFFLE(2, 3, 0, 1));
-        __m128 sums0 = _mm_add_ps(square0, shuf0);
-        shuf0 = _mm_permute_ps(sums0, _MM_SHUFFLE(1, 0, 3, 2));
-        sums0 = _mm_add_ps(sums0, shuf0);
-
-        __m128 shuf1 = _mm_permute_ps(square1, _MM_SHUFFLE(2, 3, 0, 1));
-        __m128 sums1 = _mm_add_ps(square1, shuf1);
-        shuf1 = _mm_permute_ps(sums1, _MM_SHUFFLE(1, 0, 3, 2));
-        sums1 = _mm_add_ps(sums1, shuf1);
-
-        __m128 shuf2 = _mm_permute_ps(square2, _MM_SHUFFLE(2, 3, 0, 1));
-        __m128 sums2 = _mm_add_ps(square2, shuf2);
-        shuf2 = _mm_permute_ps(sums2, _MM_SHUFFLE(1, 0, 3, 2));
-        sums2 = _mm_add_ps(sums2, shuf2);
-
-        // Extract sums and compute sqrt with signs
+        // extract sums and compute sqrt with signs
         return Vector3(
-            xs * sqrt(_mm_cvtss_f32(sums0)),
-            ys * sqrt(_mm_cvtss_f32(sums1)),
-            zs * sqrt(_mm_cvtss_f32(sums2))
+            xs * sqrt(_mm_cvtss_f32(len_sq0)),
+            ys * sqrt(_mm_cvtss_f32(len_sq1)),
+            zs * sqrt(_mm_cvtss_f32(len_sq2))
         );
 #else
         const int xs = (sign(m00 * m01 * m02 * m03) < 0) ? -1 : 1;
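
The 0x71 immediate of _mm_dp_ps (an SSE4.1 instruction, available on any AVX2 target) encodes two masks: the high nibble 0b0111 feeds lanes 0-2 into the multiply-and-sum, and the low nibble 0b0001 writes the sum to lane 0 only, which is why the duplicated fourth gather index is harmless. The idiom in isolation:

    #include <immintrin.h>

    // squared length of the xyz lanes; lane 3 is masked out of the sum
    inline float length_squared_xyz(__m128 v)
    {
        return _mm_cvtss_f32(_mm_dp_ps(v, v, 0x71)); // x*x + y*y + z*z in lane 0
    }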
@@ -344,12 +327,39 @@ namespace spartan::math
         void Transpose() { *this = Transpose(*this); }
         static Matrix Transpose(const Matrix& matrix)
         {
+#if defined(__AVX2__)
+            // load all 4 columns
+            __m128 col0 = _mm_loadu_ps(&matrix.m00);
+            __m128 col1 = _mm_loadu_ps(&matrix.m01);
+            __m128 col2 = _mm_loadu_ps(&matrix.m02);
+            __m128 col3 = _mm_loadu_ps(&matrix.m03);
+
+            // transpose using unpack operations
+            __m128 tmp0 = _mm_unpacklo_ps(col0, col1); // m00, m01, m10, m11
+            __m128 tmp1 = _mm_unpackhi_ps(col0, col1); // m20, m21, m30, m31
+            __m128 tmp2 = _mm_unpacklo_ps(col2, col3); // m02, m03, m12, m13
+            __m128 tmp3 = _mm_unpackhi_ps(col2, col3); // m22, m23, m32, m33
+
+            // final shuffle to get transposed rows
+            __m128 row0 = _mm_movelh_ps(tmp0, tmp2); // m00, m01, m02, m03
+            __m128 row1 = _mm_movehl_ps(tmp2, tmp0); // m10, m11, m12, m13
+            __m128 row2 = _mm_movelh_ps(tmp1, tmp3); // m20, m21, m22, m23
+            __m128 row3 = _mm_movehl_ps(tmp3, tmp1); // m30, m31, m32, m33
+
+            Matrix result;
+            _mm_storeu_ps(&result.m00, row0);
+            _mm_storeu_ps(&result.m01, row1);
+            _mm_storeu_ps(&result.m02, row2);
+            _mm_storeu_ps(&result.m03, row3);
+            return result;
+#else
             return Matrix(
                 matrix.m00, matrix.m10, matrix.m20, matrix.m30,
                 matrix.m01, matrix.m11, matrix.m21, matrix.m31,
                 matrix.m02, matrix.m12, matrix.m22, matrix.m32,
                 matrix.m03, matrix.m13, matrix.m23, matrix.m33
             );
+#endif
         }
 
         [[nodiscard]] Matrix Inverted() const { return Invert(*this); }
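
The unpack/movelh sequence is the standard 4x4 single-precision transpose; SSE ships the same job as the _MM_TRANSPOSE4_PS macro in <xmmintrin.h>. An equivalent sketch over any 16 contiguous floats:

    #include <xmmintrin.h>

    // in-place 4x4 transpose using the stock SSE macro
    void transpose_4x4(float* m)
    {
        __m128 a = _mm_loadu_ps(m + 0);
        __m128 b = _mm_loadu_ps(m + 4);
        __m128 c = _mm_loadu_ps(m + 8);
        __m128 d = _mm_loadu_ps(m + 12);
        _MM_TRANSPOSE4_PS(a, b, c, d); // transposes the 4x4 block held in a..d
        _mm_storeu_ps(m + 0,  a);
        _mm_storeu_ps(m + 4,  b);
        _mm_storeu_ps(m + 8,  c);
        _mm_storeu_ps(m + 12, d);
    }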
@@ -436,27 +446,27 @@ namespace spartan::math
             const float* left_data  = Data();
             const float* right_data = rhs.Data();
 
-            // Load left columns
+            // load left columns
             __m128 left_col0 = _mm_loadu_ps(left_data + 0);
             __m128 left_col1 = _mm_loadu_ps(left_data + 4);
             __m128 left_col2 = _mm_loadu_ps(left_data + 8);
             __m128 left_col3 = _mm_loadu_ps(left_data + 12);
 
             for (int j = 0; j < 4; ++j)
             {
-                // Broadcast elements of right column j using AVX broadcast
+                // broadcast elements of right column j
                 __m128 v0 = _mm_broadcast_ss(right_data + 4 * j + 0);
                 __m128 v1 = _mm_broadcast_ss(right_data + 4 * j + 1);
                 __m128 v2 = _mm_broadcast_ss(right_data + 4 * j + 2);
                 __m128 v3 = _mm_broadcast_ss(right_data + 4 * j + 3);
 
-                // Compute using FMA for AVX2
+                // compute using fma
                 __m128 res = _mm_mul_ps(left_col0, v0);
                 res = _mm_fmadd_ps(left_col1, v1, res);
                 res = _mm_fmadd_ps(left_col2, v2, res);
                 res = _mm_fmadd_ps(left_col3, v3, res);
 
-                // Store result column
+                // store result column
                 _mm_storeu_ps(const_cast<float*>(result.Data()) + 4 * j, res);
             }
 
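
Each iteration builds result column j as a linear combination of the left matrix's columns, weighted by the four scalars of the right matrix's column j; this is the column-major formulation of matrix multiplication. A scalar sketch of the same computation, assuming 16 contiguous floats per matrix with column j at offset 4 * j:

    // scalar equivalent of the broadcast/FMA loop above
    void multiply_col_major(const float* a, const float* b, float* out)
    {
        for (int j = 0; j < 4; ++j)         // result column
            for (int i = 0; i < 4; ++i)     // element within that column
            {
                float sum = 0.0f;
                for (int k = 0; k < 4; ++k) // a's columns weighted by b's column j
                    sum += a[4 * k + i] * b[4 * j + k];
                out[4 * j + i] = sum;
            }
    }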
@@ -487,6 +497,49 @@ namespace spartan::math
 
         Vector3 operator*(const Vector3& rhs) const
         {
+#if defined(__AVX2__)
+            // load matrix columns
+            __m128 col0 = _mm_loadu_ps(&m00);
+            __m128 col1 = _mm_loadu_ps(&m01);
+            __m128 col2 = _mm_loadu_ps(&m02);
+            __m128 col3 = _mm_loadu_ps(&m03);
+
+            // transpose columns to rows for correct row-vector multiplication
+            __m128 tmp0 = _mm_unpacklo_ps(col0, col1);
+            __m128 tmp1 = _mm_unpackhi_ps(col0, col1);
+            __m128 tmp2 = _mm_unpacklo_ps(col2, col3);
+            __m128 tmp3 = _mm_unpackhi_ps(col2, col3);
+
+            __m128 row0 = _mm_movelh_ps(tmp0, tmp2);
+            __m128 row1 = _mm_movehl_ps(tmp2, tmp0);
+            __m128 row2 = _mm_movelh_ps(tmp1, tmp3);
+            __m128 row3 = _mm_movehl_ps(tmp3, tmp1);
+
+            // broadcast vector components
+            __m128 vx = _mm_set1_ps(rhs.x);
+            __m128 vy = _mm_set1_ps(rhs.y);
+            __m128 vz = _mm_set1_ps(rhs.z);
+
+            // multiply and accumulate: result = row0*x + row1*y + row2*z + row3*1
+            __m128 result = _mm_mul_ps(row0, vx);
+            result = _mm_fmadd_ps(row1, vy, result);
+            result = _mm_fmadd_ps(row2, vz, result);
+            result = _mm_add_ps(result, row3);
+
+            // extract w for perspective divide
+            float w = _mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(3, 3, 3, 3)));
+
+            // perspective divide if needed
+            if (w != 1.0f)
+            {
+                __m128 inv_w = _mm_set1_ps(1.0f / w);
+                result = _mm_mul_ps(result, inv_w);
+            }
+
+            return Vector3(_mm_cvtss_f32(result),
+                _mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 1, 1, 1))),
+                _mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(2, 2, 2, 2))));
+#else
             float x = (rhs.x * m00) + (rhs.y * m10) + (rhs.z * m20) + m30;
             float y = (rhs.x * m01) + (rhs.y * m11) + (rhs.z * m21) + m31;
             float z = (rhs.x * m02) + (rhs.y * m12) + (rhs.z * m22) + m32;
@@ -501,23 +554,75 @@ namespace spartan::math
             }
 
             return Vector3(x, y, z);
+#endif
         }
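
With columns stored contiguously, lane i of the FMA result is the dot product of stored column i with (x, y, z, 1); the transpose merely re-expresses that as a row-weighted sum. The w != 1.0f branch then applies the homogeneous divide, which only projection matrices trigger; affine transforms leave w at 1 and skip it. A usage sketch (matrix and vector names assumed):

    // affine world transform: w stays 1, no divide
    Vector3 world_pos = world_matrix * local_pos;
    // projective transform: w != 1, result is divided through by w
    Vector3 ndc_pos = view_projection_matrix * world_pos;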
 
         Vector4 operator*(const Vector4& rhs) const
         {
+#if defined(__AVX2__)
+            // load matrix columns
+            __m128 col0 = _mm_loadu_ps(&m00);
+            __m128 col1 = _mm_loadu_ps(&m01);
+            __m128 col2 = _mm_loadu_ps(&m02);
+            __m128 col3 = _mm_loadu_ps(&m03);
+
+            // transpose columns to rows for correct row-vector multiplication
+            __m128 tmp0 = _mm_unpacklo_ps(col0, col1);
+            __m128 tmp1 = _mm_unpackhi_ps(col0, col1);
+            __m128 tmp2 = _mm_unpacklo_ps(col2, col3);
+            __m128 tmp3 = _mm_unpackhi_ps(col2, col3);
+
+            __m128 row0 = _mm_movelh_ps(tmp0, tmp2);
+            __m128 row1 = _mm_movehl_ps(tmp2, tmp0);
+            __m128 row2 = _mm_movelh_ps(tmp1, tmp3);
+            __m128 row3 = _mm_movehl_ps(tmp3, tmp1);
+
+            // broadcast vector components
+            __m128 vx = _mm_set1_ps(rhs.x);
+            __m128 vy = _mm_set1_ps(rhs.y);
+            __m128 vz = _mm_set1_ps(rhs.z);
+            __m128 vw = _mm_set1_ps(rhs.w);
+
+            // multiply and accumulate: result = row0*x + row1*y + row2*z + row3*w
+            __m128 result = _mm_mul_ps(row0, vx);
+            result = _mm_fmadd_ps(row1, vy, result);
+            result = _mm_fmadd_ps(row2, vz, result);
+            result = _mm_fmadd_ps(row3, vw, result);
+
+            return Vector4(_mm_cvtss_f32(result),
+                _mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 1, 1, 1))),
+                _mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(2, 2, 2, 2))),
+                _mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(3, 3, 3, 3))));
+#else
             return Vector4
             (
                 (rhs.x * m00) + (rhs.y * m10) + (rhs.z * m20) + (rhs.w * m30),
                 (rhs.x * m01) + (rhs.y * m11) + (rhs.z * m21) + (rhs.w * m31),
                 (rhs.x * m02) + (rhs.y * m12) + (rhs.z * m22) + (rhs.w * m32),
                 (rhs.x * m03) + (rhs.y * m13) + (rhs.z * m23) + (rhs.w * m33)
             );
+#endif
         }
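
Since lane i of the result equals the dot product of stored column i with rhs, the transpose can be skipped entirely in exchange for four dot-product instructions; whether that is faster depends on the target, so the following is only an alternative sketch under the same assumed column layout:

    #include <immintrin.h>

    // one _mm_dp_ps per output component, no transpose
    // (cols points at 16 contiguous floats, column i at offset 4 * i)
    void transform_vec4(const float* cols, const float in[4], float out[4])
    {
        __m128 v = _mm_loadu_ps(in);
        out[0] = _mm_cvtss_f32(_mm_dp_ps(_mm_loadu_ps(cols + 0),  v, 0xF1));
        out[1] = _mm_cvtss_f32(_mm_dp_ps(_mm_loadu_ps(cols + 4),  v, 0xF1));
        out[2] = _mm_cvtss_f32(_mm_dp_ps(_mm_loadu_ps(cols + 8),  v, 0xF1));
        out[3] = _mm_cvtss_f32(_mm_dp_ps(_mm_loadu_ps(cols + 12), v, 0xF1));
    }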
 
         bool operator==(const Matrix& rhs) const
         {
-            const float* data_left  = Data();
-            const float* data_right = rhs.Data();
+#if defined(__AVX2__)
+            // load and compare all 4 columns using 256-bit registers for efficiency
+            __m256 left01  = _mm256_loadu_ps(&m00); // columns 0 and 1
+            __m256 left23  = _mm256_loadu_ps(&m02); // columns 2 and 3
+            __m256 right01 = _mm256_loadu_ps(&rhs.m00);
+            __m256 right23 = _mm256_loadu_ps(&rhs.m02);
+
+            // compare for equality
+            __m256 cmp01 = _mm256_cmp_ps(left01, right01, _CMP_EQ_OQ);
+            __m256 cmp23 = _mm256_cmp_ps(left23, right23, _CMP_EQ_OQ);
+
+            // combine results and check all bits are set
+            __m256 cmp_all = _mm256_and_ps(cmp01, cmp23);
+            return _mm256_movemask_ps(cmp_all) == 0xFF;
+#else
+            const float* data_left  = Data();
+            const float* data_right = rhs.Data();
 
             for (unsigned i = 0; i < 16; ++i)
             {
@@ -526,12 +631,40 @@ namespace spartan::math
             }
 
             return true;
+#endif
         }
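
_mm256_movemask_ps packs the sign bit of each of the eight comparison lanes into the low byte of an int, so 0xFF means every lane compared equal. With _CMP_EQ_OQ a NaN in either operand compares false, which matches the scalar loop's != behavior. Distilled:

    #include <immintrin.h>

    // all eight float lanes equal? (ordered compare: NaN yields false)
    inline bool all_equal(__m256 a, __m256 b)
    {
        return _mm256_movemask_ps(_mm256_cmp_ps(a, b, _CMP_EQ_OQ)) == 0xFF;
    }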
 
         bool operator!=(const Matrix& rhs) const { return !(*this == rhs); }
 
-        bool Equals(const Matrix& rhs)
+        bool Equals(const Matrix& rhs) const
         {
+#if defined(__AVX2__)
+            const float eps = std::numeric_limits<float>::epsilon();
+            __m256 epsilon  = _mm256_set1_ps(eps);
+
+            // load columns
+            __m256 left01  = _mm256_loadu_ps(&m00);
+            __m256 left23  = _mm256_loadu_ps(&m02);
+            __m256 right01 = _mm256_loadu_ps(&rhs.m00);
+            __m256 right23 = _mm256_loadu_ps(&rhs.m02);
+
+            // compute absolute difference
+            __m256 diff01 = _mm256_sub_ps(left01, right01);
+            __m256 diff23 = _mm256_sub_ps(left23, right23);
+
+            // absolute value (clear sign bit)
+            __m256 sign_mask = _mm256_set1_ps(-0.0f);
+            diff01 = _mm256_andnot_ps(sign_mask, diff01);
+            diff23 = _mm256_andnot_ps(sign_mask, diff23);
+
+            // compare against epsilon
+            __m256 cmp01 = _mm256_cmp_ps(diff01, epsilon, _CMP_LE_OQ);
+            __m256 cmp23 = _mm256_cmp_ps(diff23, epsilon, _CMP_LE_OQ);
+
+            // combine and check all pass
+            __m256 cmp_all = _mm256_and_ps(cmp01, cmp23);
+            return _mm256_movemask_ps(cmp_all) == 0xFF;
+#else
             const float* data_left  = Data();
             const float* data_right = rhs.Data();
 
@@ -542,6 +675,7 @@ namespace spartan::math
             }
 
             return true;
+#endif
         }
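
The absolute value here uses the andnot-with-negative-zero idiom: -0.0f is a float whose only set bit is the sign bit, so andnot(sign_mask, x) clears each lane's sign. Note that std::numeric_limits<float>::epsilon() (about 1.19e-7) is a tolerance calibrated for values near 1.0, so this comparison is strict for large matrix elements. The idiom in isolation:

    #include <immintrin.h>

    // lane-wise |x|: clear the sign bit of every float lane
    inline __m256 abs_ps(__m256 x)
    {
        return _mm256_andnot_ps(_mm256_set1_ps(-0.0f), x);
    }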
 
         [[nodiscard]] const float* Data() const { return &m00; }
@@ -556,6 +690,6 @@ namespace spartan::math
     };
 
     // reverse order operators
-    inline Vector3 operator*(const Vector3& lhs, const Matrix& rhs) { return rhs * lhs; }
-    inline Vector4 operator*(const Vector4& lhs, const Matrix& rhs) { return rhs * lhs; }
+    inline Vector3 operator*(const Vector3& lhs, const Matrix& rhs) { return rhs * lhs; }
+    inline Vector4 operator*(const Vector4& lhs, const Matrix& rhs) { return rhs * lhs; }
 }