Unsurprisingly, large fractions of computations in [modular arithmetic](../modular) are often spent on calculating the modulo operation, which is as slow as general integer division and typically takes 15-20 cycles, depending on the operand size.
This means that, after we normally multiply two numbers in the Montgomery space, we need to *reduce* the result — multiply it by $r^{-1}$ and take it modulo $n$ — and there is an efficient way to do this particular operation.
### Montgomery reduction
Assume that $r=2^{32}$, the modulo $n$ is 32-bit, and the number $x$ we need to reduce is 64-bit (the product of two 32-bit numbers). Our goal is to calculate $y = x \cdot r^{-1} \bmod n$.
Since $r$ is coprime with $n$, we know that there are two numbers $r^{-1}$ and $n^\prime$ in the $[0, n)$ range such that
$$
r \cdot r^{-1} + n \cdot n^\prime = 1
$$
and both $r^{-1}$ and $n^\prime$ can be computed, e.g., using the [extended Euclidean algorithm](../euclid-extended).
Using this identity, we can express $r \cdot r^{-1}$ as $(1 - n \cdot n^\prime)$ and write $x \cdot r^{-1}$ as
$$
\begin{aligned}
x \cdot r^{-1} &= x \cdot r \cdot r^{-1} / r
\\             &= x \cdot (1 - n \cdot n^{\prime}) / r
\\             &= (x - x \cdot n \cdot n^{\prime}) / r
\\             &\equiv (x - x \cdot n \cdot n^{\prime} + k \cdot r \cdot n) / r &\pmod n &\;\;\text{(for any integer $k$)}
\\             &\equiv (x - (x \cdot n^{\prime} - k \cdot r) \cdot n) / r &\pmod n
\end{aligned}
$$
Now, if we choose $k$ to be $\lfloor x \cdot n^\prime / r \rfloor$ (the upper 64 bits of the $x \cdot n^\prime$ product), the $k \cdot r$ term cancels out the upper bits, and $(x \cdot n^{\prime} - k \cdot r)$ will simply be equal to $x \cdot n^{\prime} \bmod r$ (the lower 32 bits of $x \cdot n^\prime$), implying:
$$
x \cdot r^{-1} \equiv (x - (x \cdot n^{\prime} \bmod r) \cdot n) / r \pmod n
$$
The algorithm itself just evaluates this formula: it performs two multiplications to calculate $q = x \cdot n^{\prime} \bmod r$ and $m = q \cdot n$, then subtracts $m$ from $x$ and right-shifts the result to divide it by $r$.
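To make the formula concrete, here is a small sanity check with toy parameters, which are purely an assumption of this sketch ($r = 2^8$ and $n = 239$ instead of the 32-bit values used later); `toy_inverse` is a hypothetical helper that finds $n^\prime$ by brute force:

```cpp
typedef unsigned int u32;

const u32 R = 256, N = 239; // toy parameters: r = 2^8, n = 239 (odd, so coprime with r)

// find n' = n^{-1} mod r by brute force (fine for an 8-bit toy example;
// a real implementation would use the tricks described in this article)
u32 toy_inverse() {
    for (u32 i = 1; i < R; i++)
        if (N * i % R == 1)
            return i;
    return 0; // unreachable for odd N
}

// evaluate (x - (x * n' mod r) * n) / r; the division is always exact
// because (x - q * n) is divisible by r by construction
u32 toy_reduce(u32 x) {
    u32 q = x * toy_inverse() % R;
    int a = ((int) x - (int) (q * N)) / (int) R;
    return a < 0 ? a + N : a; // bring the result to the [0, n) range
}
```

For every $x < n^2$, `toy_reduce(x)` equals $x \cdot r^{-1} \bmod n$, which can be verified by checking that `toy_reduce(x) * R % N == x % N`.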
The only remaining thing to handle is that the result may not be in the $[0, n)$ range; but since

$$
x < n \cdot n < r \cdot n \implies x / r < n
$$

and

$$
m = q \cdot n < r \cdot n \implies m / r < n
$$

it is guaranteed that

$$
-n < (x - m) / r < n
$$

Therefore, we can simply check if the result is negative and, in that case, add $n$ to it, giving the following algorithm:
```c++
typedef __uint32_t u32;
typedef __uint64_t u64;

const u32 n = 1e9 + 7, nr = inverse(n, 1ull << 32);

u32 reduce(u64 x) {
    u32 q = u32(x) * nr;      // q = x * n' mod r
    u64 m = (u64) q * n;      // m = q * n
    u32 y = (x - m) >> 32;    // y = (x - m) / r
    return x < m ? y + n : y; // if y < 0, add n to make it be in the [0, n) range
}
```
This last check is relatively cheap, but it is still on the critical path. If we are fine with the result being in the $[0, 2 \cdot n - 2]$ range instead of $[0, n)$, we can remove it and add $n$ to the result unconditionally:
```c++
u32 reduce(u64 x) {
    u32 q = u32(x) * nr;
    u64 m = (u64) q * n;
    u32 y = (x - m) >> 32;
    return y + n;
}
```
We can also move the `>> 32` operation one step earlier in the computation graph and compute $\lfloor x / r \rfloor - \lfloor m / r \rfloor$ instead of $(x - m) / r$. This is correct because the lower 32 bits of $x$ and $m$ are equal anyway, since

$$
m \equiv x \cdot n^\prime \cdot n \equiv x \pmod r
$$
But why would we voluntarily choose to perform two right-shifts instead of just one? This is beneficial because for `((u64) q * n) >> 32` we need to do a 32-by-32 multiplication and take the upper 32 bits of the result (which the x86 `mul` instruction [already writes](../hpc/arithmetic/integer/#128-bit-integers) in a separate register, so it doesn't cost anything), and the other right-shift `x >> 32` is not on the critical path.
```c++
u32 reduce(u64 x) {
    u32 q = u32(x) * nr;
    u32 m = ((u64) q * n) >> 32;
    return (x >> 32) + n - m;
}
```
One of the main advantages of Montgomery multiplication over other modular reduction methods is that it doesn't require very large data types: it only needs an $r \times r$ multiplication that extracts the lower and higher $r$ bits of the result, which [has special support](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#ig_expand=7395,7392,7269,4868,7269,7269,1820,1835,6385,5051,4909,4918,5051,7269,6423,7410,150,2138,1829,1944,3009,1029,7077,519,5183,4462,4490,1944,5055,5012,5055&techs=AVX,AVX2&text=mul) on most hardware and also makes it easily generalizable to [SIMD](../hpc/simd/) and larger data types:
```c++
typedef __uint128_t u128;

u64 reduce(u128 x) {
    u64 q = u64(x) * nr;
    u64 m = ((u128) q * n) >> 64;
    return (x >> 64) + n - m;
}
```
Note that a 128-by-64 modulo is not possible with general integer division tricks: the compiler [falls back](https://godbolt.org/z/fbEE4v4qr) to calling a slow [long arithmetic library function](https://github.com/llvm-mirror/compiler-rt/blob/69445f095c22aac2388f939bedebf224a6efcdaf/lib/builtins/udivmodti4.c#L22) to support it.
### Faster Inverse and Transform
Montgomery multiplication itself is fast, but it requires some precomputation:
- inverting $n$ modulo $r$ to compute $n^\prime$,
- transforming a number *to* the Montgomery space,
- transforming a number *from* the Montgomery space.
The last operation is already efficiently performed with the `reduce` procedure we just implemented, but the first two can be slightly optimized.

**Computing the inverse** $n^\prime = n^{-1} \bmod r$ can be done faster than with the extended Euclidean algorithm by taking advantage of the fact that $r$ is a power of two and using the following identity:
$$
a \cdot x \equiv 1 \bmod 2^k
\implies
a \cdot x \cdot (2 - a \cdot x) \equiv 1 \bmod 2^{2k}
$$
Proof (substituting $a \cdot x = 1 + m \cdot 2^k$ for some integer $m$):
$$
\begin{aligned}
a \cdot x \cdot (2 - a \cdot x)
   &= 2 \cdot a \cdot x - (a \cdot x)^2
\\ &= 2 \cdot (1 + m \cdot 2^k) - (1 + m \cdot 2^k)^2
\\ &= 2 + 2 \cdot m \cdot 2^k - 1 - 2 \cdot m \cdot 2^k - m^2 \cdot 2^{2k}
\\ &= 1 - m^2 \cdot 2^{2k}
\\ &\equiv 1 \pmod{2^{2k}}
\end{aligned}
$$
We can start with $x = 1$ as the inverse of $a$ modulo $2^1$ and apply this identity exactly $\log_2 \log_2 r$ times, each time doubling the number of correct bits in the inverse, somewhat reminiscent of [Newton's method](../hpc/arithmetic/newton/).
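As an illustrative sketch (assuming $r = 2^{32}$, so that unsigned 32-bit overflow performs the reduction modulo $r$ for free; `inverse_mod_pow2` is a hypothetical name):

```cpp
typedef unsigned int u32;

// compute n^{-1} mod 2^32 for an odd n:
// x = 1 is correct modulo 2^1, and each x *= (2 - n * x) step
// doubles the number of correct bits,
// so 5 iterations (1 -> 2 -> 4 -> 8 -> 16 -> 32 bits) suffice
u32 inverse_mod_pow2(u32 n) {
    u32 x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - n * x;
    return x;
}
```

For example, `inverse_mod_pow2(1000000007)` multiplied by `1000000007` in `u32` arithmetic wraps around to exactly `1`.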
**Transforming** a number into the Montgomery space can be done by multiplying it by $r$ and computing modulo [the usual way](../hpc/arithmetic/division/), but we can also take advantage of this relation:
$$
\bar{x} = x \cdot r \bmod n = x * r^2
$$
Transforming a number into the space is just a Montgomery multiplication by $r^2$ (this is what the $*$ above denotes). Therefore, we can precompute $r^2 \bmod n$ and perform a multiplication and reduction instead, which may or may not be actually faster, because multiplying a number by $r=2^{k}$ can be implemented with a left shift, while multiplication by $r^2 \bmod n$ cannot.
### Complete Implementation

It is convenient to wrap everything into a single `constexpr` structure:

```c++
struct Montgomery {
    u32 n, nr;
    
    constexpr Montgomery(u32 n) : n(n), nr(1) {
        // 5 iterations double the precision from 1 bit to 32
        for (int i = 0; i < 5; i++)
            nr *= 2 - n * nr;
    }

    u32 reduce(u64 x) const {
        u32 q = u32(x) * nr;
        u32 m = ((u64) q * n) >> 32;
        return (x >> 32) + n - m;
        // returns a number in the [0, 2 * n - 2] range
        // (add a "x < n ? x : x - n" type of check if you need a proper modulo)
    }

    u32 multiply(u32 x, u32 y) const {
        return reduce((u64) x * y);
    }

    u32 transform(u32 x) const {
        return (u64(x) << 32) % n;
        // can also be implemented as multiply(x, r^2 mod n)
    }
};
```
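A typical round-trip through the Montgomery space then looks as follows. This is a self-contained sketch that restates the structure for completeness; the modulo `M` and the `modmul` wrapper with its final range fix-up are assumptions of this example, not part of the article's code:

```cpp
typedef unsigned int u32;
typedef unsigned long long u64;

struct Montgomery {
    u32 n, nr;

    constexpr Montgomery(u32 n) : n(n), nr(1) {
        for (int i = 0; i < 5; i++)
            nr *= 2 - n * nr; // Newton-like inversion of n modulo 2^32
    }

    u32 reduce(u64 x) const {
        u32 q = u32(x) * nr;
        u32 m = ((u64) q * n) >> 32;
        return (x >> 32) + n - m; // not a proper modulo: may be as large as ~2n
    }

    u32 multiply(u32 x, u32 y) const { return reduce((u64) x * y); }
    u32 transform(u32 x) const { return (u64(x) << 32) % n; }
};

const u32 M = 1e9 + 7; // example modulo (an assumption of this sketch)
constexpr Montgomery space(M);

// compute (a * b) mod M by a round-trip through the Montgomery space
u32 modmul(u32 a, u32 b) {
    u32 x = space.transform(a);
    u32 y = space.transform(b);
    u32 c = space.reduce(space.multiply(x, y));
    return c >= M ? c - M : c; // fix-up, since reduce may overshoot by M
}
```

In real applications, the `transform`/`reduce` pair is done once, and many `multiply` calls happen in between.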
To test its performance, we can plug Montgomery multiplication into the [binary exponentiation](../hpc/number-theory/exponentiation/):
```c++
constexpr Montgomery space(M);

int inverse(int _a) {
    u64 a = space.transform(_a);
    u64 r = space.transform(1);
    
    int n = M - 2;
    while (n) {
        if (n & 1)
            r = space.multiply(r, a);
        a = space.multiply(a, a);
        n >>= 1;
    }

    return space.reduce(r);
}
```
While vanilla binary exponentiation with a compiler-generated fast modulo trick requires ~170ns per `inverse` call, this implementation takes ~166ns, going down to ~158ns if we omit `transform` and `reduce` (a reasonable use case: in modular arithmetic, `inverse` is often just a subprocedure in a bigger computation). This is a small improvement, but Montgomery multiplication becomes much more advantageous for SIMD applications and larger data types.