Skip to content

Commit 073f946

Browse files
committed
[math] optimized matrix
1 parent 2b80f59 commit 073f946

1 file changed

Lines changed: 184 additions & 50 deletions

File tree

source/runtime/Math/Matrix.h

Lines changed: 184 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -126,30 +126,30 @@ namespace spartan::math
126126

127127
Matrix normalized;
128128
#if defined(__AVX2__)
129-
// Create scale reciprocal vector for faster division
130-
__m128 scaleRecip = _mm_set_ps(1.0f, 1.0f / scale.z, 1.0f / scale.y, 1.0f / scale.x);
129+
// create scale reciprocal vector for faster division (per-row scaling)
130+
__m128 scale_recip = _mm_set_ps(1.0f, 1.0f / scale.z, 1.0f / scale.y, 1.0f / scale.x);
131131

132-
// Load and normalize first three columns using loadu for unaligned safety
132+
// load and normalize first three columns
133133
__m128 col0 = _mm_loadu_ps(Data() + 0);
134134
__m128 col1 = _mm_loadu_ps(Data() + 4);
135135
__m128 col2 = _mm_loadu_ps(Data() + 8);
136136

137-
// Normalize by scale
138-
col0 = _mm_mul_ps(col0, scaleRecip);
139-
col1 = _mm_mul_ps(col1, scaleRecip);
140-
col2 = _mm_mul_ps(col2, scaleRecip);
137+
// normalize by scale (each row element divided by its row's scale)
138+
col0 = _mm_mul_ps(col0, scale_recip);
139+
col1 = _mm_mul_ps(col1, scale_recip);
140+
col2 = _mm_mul_ps(col2, scale_recip);
141141

142-
// Store in normalized matrix
142+
// store in normalized matrix
143143
_mm_storeu_ps(&normalized.m00, col0);
144144
_mm_storeu_ps(&normalized.m01, col1);
145145
_mm_storeu_ps(&normalized.m02, col2);
146146

147-
// Set translation components to zero (last row except m33)
147+
// set translation components to zero (last row except m33)
148148
normalized.m30 = 0.0f;
149149
normalized.m31 = 0.0f;
150150
normalized.m32 = 0.0f;
151151

152-
// Set last column
152+
// set last column
153153
normalized.m03 = 0.0f;
154154
normalized.m13 = 0.0f;
155155
normalized.m23 = 0.0f;
@@ -225,48 +225,31 @@ namespace spartan::math
225225
#if defined(__AVX2__)
226226
const float* data = Data();
227227

228-
// Calculate signs (scalar)
228+
// calculate signs (scalar)
229229
float xs = (sign(m00 * m01 * m02 * m03) < 0) ? -1.0f : 1.0f;
230230
float ys = (sign(m10 * m11 * m12 * m13) < 0) ? -1.0f : 1.0f;
231231
float zs = (sign(m20 * m21 * m22 * m23) < 0) ? -1.0f : 1.0f;
232232

233-
// Define gather indices for rows (in float offsets)
234-
__m128i idx0 = _mm_set_epi32(12, 8, 4, 0); // m03, m02, m01, m00 (m03 typically 0)
235-
__m128i idx1 = _mm_set_epi32(13, 9, 5, 1); // m13, m12, m11, m10 (m13 typically 0)
236-
__m128i idx2 = _mm_set_epi32(14, 10, 6, 2); // m23, m22, m21, m20 (m23 typically 0)
233+
// define gather indices for rows (in float offsets) - only first 3 elements matter
234+
__m128i idx0 = _mm_set_epi32(0, 8, 4, 0); // m00, m01, m02 (last ignored)
235+
__m128i idx1 = _mm_set_epi32(1, 9, 5, 1); // m10, m11, m12 (last ignored)
236+
__m128i idx2 = _mm_set_epi32(2, 10, 6, 2); // m20, m21, m22 (last ignored)
237237

238-
// Gather rows using AVX2 gather
238+
// gather rows using avx2 gather
239239
__m128 row0 = _mm_i32gather_ps(data, idx0, 4);
240240
__m128 row1 = _mm_i32gather_ps(data, idx1, 4);
241241
__m128 row2 = _mm_i32gather_ps(data, idx2, 4);
242242

243-
// Square each component
244-
__m128 square0 = _mm_mul_ps(row0, row0);
245-
__m128 square1 = _mm_mul_ps(row1, row1);
246-
__m128 square2 = _mm_mul_ps(row2, row2);
243+
// use dot product to sum squares of first 3 elements (mask 0x71 = sum xyz, result in lowest)
244+
__m128 len_sq0 = _mm_dp_ps(row0, row0, 0x71);
245+
__m128 len_sq1 = _mm_dp_ps(row1, row1, 0x71);
246+
__m128 len_sq2 = _mm_dp_ps(row2, row2, 0x71);
247247

248-
// Sum using permute and add (avoid hadd for better performance)
249-
// For square0: sum all four elements (though fourth is 0^2)
250-
__m128 shuf0 = _mm_permute_ps(square0, _MM_SHUFFLE(2, 3, 0, 1));
251-
__m128 sums0 = _mm_add_ps(square0, shuf0);
252-
shuf0 = _mm_permute_ps(sums0, _MM_SHUFFLE(1, 0, 3, 2));
253-
sums0 = _mm_add_ps(sums0, shuf0);
254-
255-
__m128 shuf1 = _mm_permute_ps(square1, _MM_SHUFFLE(2, 3, 0, 1));
256-
__m128 sums1 = _mm_add_ps(square1, shuf1);
257-
shuf1 = _mm_permute_ps(sums1, _MM_SHUFFLE(1, 0, 3, 2));
258-
sums1 = _mm_add_ps(sums1, shuf1);
259-
260-
__m128 shuf2 = _mm_permute_ps(square2, _MM_SHUFFLE(2, 3, 0, 1));
261-
__m128 sums2 = _mm_add_ps(square2, shuf2);
262-
shuf2 = _mm_permute_ps(sums2, _MM_SHUFFLE(1, 0, 3, 2));
263-
sums2 = _mm_add_ps(sums2, shuf2);
264-
265-
// Extract sums and compute sqrt with signs
248+
// extract sums and compute sqrt with signs
266249
return Vector3(
267-
xs * sqrt(_mm_cvtss_f32(sums0)),
268-
ys * sqrt(_mm_cvtss_f32(sums1)),
269-
zs * sqrt(_mm_cvtss_f32(sums2))
250+
xs * sqrt(_mm_cvtss_f32(len_sq0)),
251+
ys * sqrt(_mm_cvtss_f32(len_sq1)),
252+
zs * sqrt(_mm_cvtss_f32(len_sq2))
270253
);
271254
#else
272255
const int xs = (sign(m00 * m01 * m02 * m03) < 0) ? -1 : 1;
@@ -344,12 +327,39 @@ namespace spartan::math
344327
void Transpose() { *this = Transpose(*this); }
345328
static Matrix Transpose(const Matrix& matrix)
346329
{
330+
#if defined(__AVX2__)
331+
// load all 4 columns
332+
__m128 col0 = _mm_loadu_ps(&matrix.m00);
333+
__m128 col1 = _mm_loadu_ps(&matrix.m01);
334+
__m128 col2 = _mm_loadu_ps(&matrix.m02);
335+
__m128 col3 = _mm_loadu_ps(&matrix.m03);
336+
337+
// transpose using unpack operations
338+
__m128 tmp0 = _mm_unpacklo_ps(col0, col1); // m00, m01, m10, m11
339+
__m128 tmp1 = _mm_unpackhi_ps(col0, col1); // m20, m21, m30, m31
340+
__m128 tmp2 = _mm_unpacklo_ps(col2, col3); // m02, m03, m12, m13
341+
__m128 tmp3 = _mm_unpackhi_ps(col2, col3); // m22, m23, m32, m33
342+
343+
// final shuffle to get transposed rows
344+
__m128 row0 = _mm_movelh_ps(tmp0, tmp2); // m00, m01, m02, m03
345+
__m128 row1 = _mm_movehl_ps(tmp2, tmp0); // m10, m11, m12, m13
346+
__m128 row2 = _mm_movelh_ps(tmp1, tmp3); // m20, m21, m22, m23
347+
__m128 row3 = _mm_movehl_ps(tmp3, tmp1); // m30, m31, m32, m33
348+
349+
Matrix result;
350+
_mm_storeu_ps(&result.m00, row0);
351+
_mm_storeu_ps(&result.m01, row1);
352+
_mm_storeu_ps(&result.m02, row2);
353+
_mm_storeu_ps(&result.m03, row3);
354+
return result;
355+
#else
347356
return Matrix(
348357
matrix.m00, matrix.m10, matrix.m20, matrix.m30,
349358
matrix.m01, matrix.m11, matrix.m21, matrix.m31,
350359
matrix.m02, matrix.m12, matrix.m22, matrix.m32,
351360
matrix.m03, matrix.m13, matrix.m23, matrix.m33
352361
);
362+
#endif
353363
}
354364

355365
[[nodiscard]] Matrix Inverted() const { return Invert(*this); }
@@ -436,27 +446,27 @@ namespace spartan::math
436446
const float* left_data = Data();
437447
const float* right_data = rhs.Data();
438448

439-
// Load left columns
449+
// load left columns
440450
__m128 left_col0 = _mm_loadu_ps(left_data + 0);
441451
__m128 left_col1 = _mm_loadu_ps(left_data + 4);
442452
__m128 left_col2 = _mm_loadu_ps(left_data + 8);
443453
__m128 left_col3 = _mm_loadu_ps(left_data + 12);
444454

445455
for (int j = 0; j < 4; ++j)
446456
{
447-
// Broadcast elements of right column j using AVX broadcast
457+
// broadcast elements of right column j
448458
__m128 v0 = _mm_broadcast_ss(right_data + 4 * j + 0);
449459
__m128 v1 = _mm_broadcast_ss(right_data + 4 * j + 1);
450460
__m128 v2 = _mm_broadcast_ss(right_data + 4 * j + 2);
451461
__m128 v3 = _mm_broadcast_ss(right_data + 4 * j + 3);
452462

453-
// Compute using FMA for AVX2
463+
// compute using fma
454464
__m128 res = _mm_mul_ps(left_col0, v0);
455465
res = _mm_fmadd_ps(left_col1, v1, res);
456466
res = _mm_fmadd_ps(left_col2, v2, res);
457467
res = _mm_fmadd_ps(left_col3, v3, res);
458468

459-
// Store result column
469+
// store result column
460470
_mm_storeu_ps(const_cast<float*>(result.Data()) + 4 * j, res);
461471
}
462472

@@ -487,6 +497,49 @@ namespace spartan::math
487497

488498
Vector3 operator*(const Vector3& rhs) const
489499
{
500+
#if defined(__AVX2__)
501+
// load matrix columns
502+
__m128 col0 = _mm_loadu_ps(&m00);
503+
__m128 col1 = _mm_loadu_ps(&m01);
504+
__m128 col2 = _mm_loadu_ps(&m02);
505+
__m128 col3 = _mm_loadu_ps(&m03);
506+
507+
// transpose columns to rows for correct row-vector multiplication
508+
__m128 tmp0 = _mm_unpacklo_ps(col0, col1);
509+
__m128 tmp1 = _mm_unpackhi_ps(col0, col1);
510+
__m128 tmp2 = _mm_unpacklo_ps(col2, col3);
511+
__m128 tmp3 = _mm_unpackhi_ps(col2, col3);
512+
513+
__m128 row0 = _mm_movelh_ps(tmp0, tmp2);
514+
__m128 row1 = _mm_movehl_ps(tmp2, tmp0);
515+
__m128 row2 = _mm_movelh_ps(tmp1, tmp3);
516+
__m128 row3 = _mm_movehl_ps(tmp3, tmp1);
517+
518+
// broadcast vector components
519+
__m128 vx = _mm_set1_ps(rhs.x);
520+
__m128 vy = _mm_set1_ps(rhs.y);
521+
__m128 vz = _mm_set1_ps(rhs.z);
522+
523+
// multiply and accumulate: result = row0*x + row1*y + row2*z + row3*1
524+
__m128 result = _mm_mul_ps(row0, vx);
525+
result = _mm_fmadd_ps(row1, vy, result);
526+
result = _mm_fmadd_ps(row2, vz, result);
527+
result = _mm_add_ps(result, row3);
528+
529+
// extract w for perspective divide
530+
float w = _mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(3, 3, 3, 3)));
531+
532+
// perspective divide if needed
533+
if (w != 1.0f)
534+
{
535+
__m128 inv_w = _mm_set1_ps(1.0f / w);
536+
result = _mm_mul_ps(result, inv_w);
537+
}
538+
539+
return Vector3(_mm_cvtss_f32(result),
540+
_mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 1, 1, 1))),
541+
_mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(2, 2, 2, 2))));
542+
#else
490543
float x = (rhs.x * m00) + (rhs.y * m10) + (rhs.z * m20) + m30;
491544
float y = (rhs.x * m01) + (rhs.y * m11) + (rhs.z * m21) + m31;
492545
float z = (rhs.x * m02) + (rhs.y * m12) + (rhs.z * m22) + m32;
@@ -501,23 +554,75 @@ namespace spartan::math
501554
}
502555

503556
return Vector3(x, y, z);
557+
#endif
504558
}
505559

506560
Vector4 operator*(const Vector4& rhs) const
507561
{
562+
#if defined(__AVX2__)
563+
// load matrix columns
564+
__m128 col0 = _mm_loadu_ps(&m00);
565+
__m128 col1 = _mm_loadu_ps(&m01);
566+
__m128 col2 = _mm_loadu_ps(&m02);
567+
__m128 col3 = _mm_loadu_ps(&m03);
568+
569+
// transpose columns to rows for correct row-vector multiplication
570+
__m128 tmp0 = _mm_unpacklo_ps(col0, col1);
571+
__m128 tmp1 = _mm_unpackhi_ps(col0, col1);
572+
__m128 tmp2 = _mm_unpacklo_ps(col2, col3);
573+
__m128 tmp3 = _mm_unpackhi_ps(col2, col3);
574+
575+
__m128 row0 = _mm_movelh_ps(tmp0, tmp2);
576+
__m128 row1 = _mm_movehl_ps(tmp2, tmp0);
577+
__m128 row2 = _mm_movelh_ps(tmp1, tmp3);
578+
__m128 row3 = _mm_movehl_ps(tmp3, tmp1);
579+
580+
// broadcast vector components
581+
__m128 vx = _mm_set1_ps(rhs.x);
582+
__m128 vy = _mm_set1_ps(rhs.y);
583+
__m128 vz = _mm_set1_ps(rhs.z);
584+
__m128 vw = _mm_set1_ps(rhs.w);
585+
586+
// multiply and accumulate: result = row0*x + row1*y + row2*z + row3*w
587+
__m128 result = _mm_mul_ps(row0, vx);
588+
result = _mm_fmadd_ps(row1, vy, result);
589+
result = _mm_fmadd_ps(row2, vz, result);
590+
result = _mm_fmadd_ps(row3, vw, result);
591+
592+
return Vector4(_mm_cvtss_f32(result),
593+
_mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 1, 1, 1))),
594+
_mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(2, 2, 2, 2))),
595+
_mm_cvtss_f32(_mm_shuffle_ps(result, result, _MM_SHUFFLE(3, 3, 3, 3))));
596+
#else
508597
return Vector4
509598
(
510599
(rhs.x * m00) + (rhs.y * m10) + (rhs.z * m20) + (rhs.w * m30),
511600
(rhs.x * m01) + (rhs.y * m11) + (rhs.z * m21) + (rhs.w * m31),
512601
(rhs.x * m02) + (rhs.y * m12) + (rhs.z * m22) + (rhs.w * m32),
513602
(rhs.x * m03) + (rhs.y * m13) + (rhs.z * m23) + (rhs.w * m33)
514603
);
604+
#endif
515605
}
516606

517607
bool operator==(const Matrix& rhs) const
518608
{
519-
const float* data_left = Data();
520-
const float* data_right = rhs.Data();
609+
#if defined(__AVX2__)
610+
// load and compare all 4 columns using 256-bit registers for efficiency
611+
__m256 left01 = _mm256_loadu_ps(&m00); // columns 0 and 1
612+
__m256 left23 = _mm256_loadu_ps(&m02); // columns 2 and 3
613+
__m256 right01 = _mm256_loadu_ps(&rhs.m00);
614+
__m256 right23 = _mm256_loadu_ps(&rhs.m02);
615+
616+
// compare for equality
617+
__m256 cmp01 = _mm256_cmp_ps(left01, right01, _CMP_EQ_OQ);
618+
__m256 cmp23 = _mm256_cmp_ps(left23, right23, _CMP_EQ_OQ);
619+
620+
// combine results and check all bits are set
621+
__m256 cmp_all = _mm256_and_ps(cmp01, cmp23);
622+
return _mm256_movemask_ps(cmp_all) == 0xFF;
623+
#else
624+
const float* data_left = Data();
625+
const float* data_right = rhs.Data();
521626

522627
for (unsigned i = 0; i < 16; ++i)
523628
{
@@ -526,12 +631,40 @@ namespace spartan::math
526631
}
527632

528633
return true;
634+
#endif
529635
}
530636

531637
bool operator!=(const Matrix& rhs) const { return !(*this == rhs); }
532638

533-
bool Equals(const Matrix& rhs)
639+
bool Equals(const Matrix& rhs) const
534640
{
641+
#if defined(__AVX2__)
642+
const float eps = std::numeric_limits<float>::epsilon();
643+
__m256 epsilon = _mm256_set1_ps(eps);
644+
645+
// load columns
646+
__m256 left01 = _mm256_loadu_ps(&m00);
647+
__m256 left23 = _mm256_loadu_ps(&m02);
648+
__m256 right01 = _mm256_loadu_ps(&rhs.m00);
649+
__m256 right23 = _mm256_loadu_ps(&rhs.m02);
650+
651+
// compute absolute difference
652+
__m256 diff01 = _mm256_sub_ps(left01, right01);
653+
__m256 diff23 = _mm256_sub_ps(left23, right23);
654+
655+
// absolute value (clear sign bit)
656+
__m256 sign_mask = _mm256_set1_ps(-0.0f);
657+
diff01 = _mm256_andnot_ps(sign_mask, diff01);
658+
diff23 = _mm256_andnot_ps(sign_mask, diff23);
659+
660+
// compare against epsilon
661+
__m256 cmp01 = _mm256_cmp_ps(diff01, epsilon, _CMP_LE_OQ);
662+
__m256 cmp23 = _mm256_cmp_ps(diff23, epsilon, _CMP_LE_OQ);
663+
664+
// combine and check all pass
665+
__m256 cmp_all = _mm256_and_ps(cmp01, cmp23);
666+
return _mm256_movemask_ps(cmp_all) == 0xFF;
667+
#else
535668
const float* data_left = Data();
536669
const float* data_right = rhs.Data();
537670

@@ -542,6 +675,7 @@ namespace spartan::math
542675
}
543676

544677
return true;
678+
#endif
545679
}
546680

547681
[[nodiscard]] const float* Data() const { return &m00; }
@@ -556,6 +690,6 @@ namespace spartan::math
556690
};
557691

558692
    // reverse order operators: delegate to the member operator* so that
    // vector * matrix yields the same result as matrix * vector
    inline Vector3 operator*(const Vector3& lhs, const Matrix& rhs) { return rhs * lhs; }
    inline Vector4 operator*(const Vector4& lhs, const Matrix& rhs) { return rhs * lhs; }
561695
}

0 commit comments

Comments
 (0)