forked from memvid/memvid
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimd.rs
More file actions
139 lines (120 loc) · 3.62 KB
/
simd.rs
File metadata and controls
139 lines (120 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
//! SIMD-accelerated distance calculations for vector search.
//!
//! This module provides optimized L2 (Euclidean) distance functions using
//! the `wide` crate for portable SIMD across `x86_64` and aarch64.
#[cfg(feature = "simd")]
use wide::f32x8;
/// Compute squared L2 distance between two f32 slices using SIMD.
///
/// The inputs are consumed in groups of eight elements, each group loaded into
/// an [`f32x8`] so the squared differences are accumulated lane-wise. Elements
/// left over when the length is not a multiple of eight are handled by scalar
/// code. Which hardware instructions back `f32x8` is decided by the `wide`
/// crate for the compilation target.
///
/// Returns the sum of `(a[i] - b[i])^2` over all `i`; an empty input yields 0.
///
/// # Panics
///
/// With debug assertions enabled, panics if `a` and `b` differ in length.
/// Without them, panics only if `b` is shorter than `a`; any elements of `b`
/// beyond `a.len()` are ignored.
#[cfg(feature = "simd")]
#[must_use]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have same length");

    // Trim `b` to `a`'s length once, up front. This is the single place a
    // too-short `b` is rejected, instead of a bounds check on every element.
    let b = &b[..a.len()];

    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    // Grab the tails before `zip` consumes the chunk iterators.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // Lane-wise accumulator of squared differences.
    let mut acc = f32x8::ZERO;
    let mut a_lanes = [0.0_f32; 8];
    let mut b_lanes = [0.0_f32; 8];
    for (a_chunk, b_chunk) in a_chunks.zip(b_chunks) {
        // `chunks_exact(8)` guarantees both chunks hold exactly 8 elements,
        // so these copies cannot panic.
        a_lanes.copy_from_slice(a_chunk);
        b_lanes.copy_from_slice(b_chunk);
        let diff = f32x8::new(a_lanes) - f32x8::new(b_lanes);
        acc += diff * diff;
    }

    // Horizontal sum of the eight lanes, in lane order.
    let lane_sums: [f32; 8] = acc.into();
    let mut total: f32 = lane_sums.iter().sum();

    // Scalar pass over the (at most 7) trailing elements.
    for (x, y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }
    total
}
/// Compute the L2 (Euclidean) distance between two f32 slices using SIMD.
///
/// This is the square root of [`l2_distance_squared_simd`]; input lengths are
/// handled exactly as that function handles them.
#[cfg(feature = "simd")]
#[must_use]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    let squared = l2_distance_squared_simd(a, b);
    squared.sqrt()
}
// Scalar fallbacks when SIMD feature is disabled
/// Compute squared L2 distance using scalar math.
///
/// Fallback used when the `simd` feature is disabled. It keeps the `_simd`
/// name so callers do not need to change with the feature flag.
///
/// Returns the sum of `(a[i] - b[i])^2` over the paired elements; an empty
/// input yields 0.
///
/// # Panics
///
/// With debug assertions enabled, panics if `a` and `b` differ in length.
/// Without them, the longer slice is silently truncated to the length of the
/// shorter one.
#[cfg(not(feature = "simd"))]
#[must_use]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have same length");
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
/// Compute L2 (Euclidean) distance using scalar math.
///
/// Fallback used when the `simd` feature is disabled. Returns the square root
/// of [`l2_distance_squared_simd`]; input lengths are handled exactly as that
/// function handles them.
#[cfg(not(feature = "simd"))]
#[must_use]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    l2_distance_squared_simd(a, b).sqrt()
}
/// Unit tests for the L2 distance helpers. They call whichever implementation
/// (SIMD or scalar fallback) the active feature set compiles in.
#[cfg(test)]
mod tests {
    use super::*;

    /// 3-4-5 triangle: squared distance must be 25.
    #[test]
    fn test_l2_distance_squared_basic() {
        let a = [0.0, 0.0, 0.0];
        let b = [3.0, 4.0, 0.0];
        let dist_sq = l2_distance_squared_simd(&a, &b);
        assert!(
            (dist_sq - 25.0).abs() < 1e-6,
            "expected 25.0, got {}",
            dist_sq
        );
    }

    /// 3-4-5 triangle: distance must be 5.
    #[test]
    fn test_l2_distance_basic() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        let dist = l2_distance_simd(&a, &b);
        assert!((dist - 5.0).abs() < 1e-6, "expected 5.0, got {}", dist);
    }

    /// Length that is a multiple of 8, compared against a plain scalar loop.
    #[test]
    fn test_l2_distance_384_dims() {
        // Test with realistic 384-dim vectors
        let a: Vec<f32> = (0..384).map(|i| i as f32 * 0.01).collect();
        let b: Vec<f32> = (0..384).map(|i| (i + 1) as f32 * 0.01).collect();
        let dist_simd = l2_distance_simd(&a, &b);
        // Compare with scalar implementation
        let dist_scalar: f32 = a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt();
        assert!(
            (dist_simd - dist_scalar).abs() < 1e-4,
            "SIMD {} vs Scalar {}",
            dist_simd,
            dist_scalar
        );
    }

    /// Length 19 = 2 * 8 + 3 exercises both the 8-wide loop and the scalar
    /// tail in a single call. Every difference is exactly 1, so the squared
    /// distance is exactly the element count.
    #[test]
    fn test_l2_distance_squared_chunks_plus_remainder() {
        let a: Vec<f32> = (0..19).map(|i| i as f32).collect();
        let b: Vec<f32> = a.iter().map(|x| x + 1.0).collect();
        let dist_sq = l2_distance_squared_simd(&a, &b);
        assert!(
            (dist_sq - 19.0).abs() < 1e-6,
            "expected 19.0, got {}",
            dist_sq
        );
    }

    /// Empty inputs have zero distance.
    #[test]
    fn test_l2_distance_empty() {
        let empty: [f32; 0] = [];
        let dist_sq = l2_distance_squared_simd(&empty, &empty);
        let dist = l2_distance_simd(&empty, &empty);
        assert!(dist_sq.abs() < 1e-6, "expected 0.0, got {}", dist_sq);
        assert!(dist.abs() < 1e-6, "expected 0.0, got {}", dist);
    }
}