Commit 73fbdf4

montgomery multiplication
1 parent 283025a commit 73fbdf4

1 file changed

+100
-84
lines changed

content/english/hpc/number-theory/montgomery.md

Lines changed: 100 additions & 84 deletions
@@ -1,7 +1,6 @@
11
---
22
title: Montgomery Multiplication
33
weight: 4
4-
draft: true
54
---
65

76
Unsurprisingly, a large fraction of computation in [modular arithmetic](../modular) is often spent on calculating the modulo operation, which is as slow as general integer division and typically takes 15-20 cycles, depending on the operand size.
@@ -87,73 +86,122 @@ This means that, after we normally multiply two numbers in the Montgomery space,
8786
8887
### Montgomery reduction
8988
90-
Assume that $r=2^{32}$, the modulo $n$ is 32-bit, and the number $x$ we need to reduce (multiply by $r^{-1}$ and take it modulo $n$) is the 64-bit the product of two 32-bit numbers.
89+
Assume that $r=2^{32}$, the modulo $n$ is 32-bit, and the number $x$ we need to reduce is 64-bit (the product of two 32-bit numbers). Our goal is to calculate $y = x \cdot r^{-1} \bmod n$.
9190
92-
By definition, $\gcd(n, r) = 1$, so we know that there are two numbers $r^{-1}$ and $n'$ in the $[0, n)$ range such that
91+
Since $r$ is coprime with $n$, we know that there are two integers $r^{-1}$ and $n^\prime$ such that
9392
9493
$$
95-
r \cdot r^{-1} + n \cdot n' = 1
94+
r \cdot r^{-1} + n \cdot n^\prime = 1
9695
$$
9796
98-
and both $r^{-1}$ and $n'$ can be computed using the [extended Euclidean algorithm](../euclid-extended).
97+
and both $r^{-1}$ and $n^\prime$ can be computed, e.g., using the [extended Euclidean algorithm](../euclid-extended).
9998
100-
Using this identity, we can express $r \cdot r^{-1}$ as $(-n \cdot n' + 1)$ and write $x \cdot r^{-1}$ as
99+
Using this identity, we can express $r \cdot r^{-1}$ as $(1 - n \cdot n^\prime)$ and write $x \cdot r^{-1}$ as
101100
102101
$$
103102
\begin{aligned}
104103
x \cdot r^{-1} &= x \cdot r \cdot r^{-1} / r
105-
\\ &= x \cdot (-n \cdot n^{\prime} + 1) / r
106-
\\ &= (-x \cdot n \cdot n^{\prime} + x) / r
107-
\\ &\equiv (-x \cdot n \cdot n^{\prime} + l \cdot r \cdot n + x) / r \bmod n
108-
\\ &\equiv ((-x \cdot n^{\prime} + l \cdot r) \cdot n + x) / r \bmod n
104+
\\ &= x \cdot (1 - n \cdot n^{\prime}) / r
105+
\\ &= (x - x \cdot n \cdot n^{\prime} ) / r
106+
\\ &\equiv (x - x \cdot n \cdot n^{\prime} + k \cdot r \cdot n) / r &\pmod n &\;\;\text{(for any integer $k$)}
107+
\\ &\equiv (x - (x \cdot n^{\prime} - k \cdot r) \cdot n) / r &\pmod n
109108
\end{aligned}
110109
$$
111110
112-
The equivalences hold for any integer $l$. This means that we can add or subtract an arbitrary multiple of $r$ to $x \cdot n'$, or in other words, we can compute $q = x \cdot n'$ modulo $r$.
111+
Now, if we choose $k$ to be $\lfloor x \cdot n^\prime / r \rfloor$ (the upper 64 bits of the $x \cdot n^\prime$ product), the upper bits will cancel out, and $(x \cdot n^{\prime} - k \cdot r)$ will simply be equal to $x \cdot n^{\prime} \bmod r$ (the lower 32 bits of $x \cdot n^\prime$), implying:
113112
114-
This gives us the following algorithm to compute $x \cdot r^{-1} \bmod n$:
113+
$$
114+
x \cdot r^{-1} \equiv (x - (x \cdot n^{\prime} \bmod r) \cdot n) / r \pmod n
115+
$$
116+
117+
The algorithm itself just evaluates this formula, performing two multiplications to calculate $q = x \cdot n^{\prime} \bmod r$ and $m = q \cdot n$, and then subtracts it from $x$ and right-shifts the result to divide it by $r$.
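
As a sanity check, here is a toy instance with numbers chosen purely for illustration (they are not from the original text): take $r = 16$, $n = 7$, and $x = 30 < n \cdot n$. Since $7 \cdot 7 = 49 \equiv 1 \pmod{16}$, we have $n^\prime \equiv 7 \pmod{16}$, and

$$
\begin{aligned}
q &= x \cdot n^{\prime} \bmod r = 210 \bmod 16 = 2
\\ m &= q \cdot n = 14
\\ (x - m) / r &= 16 / 16 = 1
\end{aligned}
$$

This agrees with the direct computation: $r^{-1} \equiv 4 \pmod 7$ because $16 \cdot 4 = 64 = 9 \cdot 7 + 1$, and $x \cdot r^{-1} = 30 \cdot 4 = 120 = 17 \cdot 7 + 1 \equiv 1 \pmod 7$.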
118+
119+
The only remaining thing to handle is that the result may not be in the $[0, n)$ range; but since
120+
121+
$$
122+
x < n \cdot n < r \cdot n \implies x / r < n
123+
$$
124+
125+
and
126+
127+
$$
128+
m = q \cdot n < r \cdot n \implies m / r < n
129+
$$
130+
131+
it is guaranteed that
132+
133+
$$
134+
-n < (x - m) / r < n
135+
$$
136+
137+
Therefore, we can simply check if the result is negative and in that case, add $n$ to it, giving the following algorithm:
138+
139+
```c++
140+
typedef __uint32_t u32;
141+
typedef __uint64_t u64;
142+
143+
const u32 n = 1e9 + 7, nr = inverse(n, 1ull << 32);
115144
116-
```python
117-
def reduce(x):
118-
q = (x % r) * nr % r
119-
a = (x - q * n) / r
120-
if a < 0:
121-
a += n
122-
return a
145+
u32 reduce(u64 x) {
146+
u32 q = u32(x) * nr; // q = x * n' mod r
147+
u64 m = (u64) q * n; // m = q * n
148+
u32 y = (x - m) >> 32; // y = (x - m) / r
149+
return x < m ? y + n : y; // if y < 0, add n to make it be in the [0, n) range
150+
}
123151
```
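
The listing above calls `inverse(n, 1ull << 32)` without defining it. The sketch below is one way to make it self-contained, assuming the extended Euclidean algorithm is used for the inverse. The body of `inverse`, the `i64` typedef, the `constexpr` qualifiers, and the `<cstdint>` typedefs are assumptions added for illustration and are not taken from the original code.

```c++
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;
typedef int64_t i64;

// Hypothetical helper: inverse of a modulo m via the extended Euclidean algorithm.
// Assumes gcd(a, m) = 1 and m <= 2^32, so all intermediates fit into signed 64-bit integers.
constexpr u64 inverse(u64 a, u64 m) {
    i64 r0 = (i64) m, r1 = (i64) (a % m); // remainder sequence
    i64 t0 = 0, t1 = 1;                   // invariant: t_i * a == r_i (mod m)
    while (r1 != 0) {
        i64 q = r0 / r1;
        i64 r2 = r0 - q * r1; r0 = r1; r1 = r2;
        i64 t2 = t0 - q * t1; t0 = t1; t1 = t2;
    }
    // here r0 = gcd(a, m) = 1, so t0 * a == 1 (mod m)
    return t0 < 0 ? t0 + (i64) m : t0;
}

constexpr u32 n = 1000000007;
constexpr u32 nr = (u32) inverse(n, 1ull << 32); // n' = n^{-1} mod 2^32

// Computes x * r^{-1} mod n for r = 2^32, assuming x < r * n
constexpr u32 reduce(u64 x) {
    u32 q = u32(x) * nr;      // q = x * n' mod r
    u64 m = (u64) q * n;      // m = q * n
    u32 y = (x - m) >> 32;    // y = (x - m) / r, modulo 2^32
    return x < m ? y + n : y; // if the true result is negative, add n
}
```

A quick way to check it is a round trip: `reduce((u64(a) << 32) % n)` should return `a` for any `a < n`, because the argument is $a \cdot r \bmod n$.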
124152

125-
Since $x < n \cdot n < r \cdot n$ and $q \cdot n < r \cdot n$, we know that
153+
This last check is relatively cheap, but it is still on the critical path. If we are fine with the result being in the $[0, 2 \cdot n - 2]$ range instead of $[0, n)$, we can remove it and add $n$ to the result unconditionally:
154+
155+
```c++
156+
u32 reduce(u64 x) {
157+
u32 q = u32(x) * nr;
158+
u64 m = (u64) q * n;
159+
u32 y = (x - m) >> 32;
160+
return y + n;
161+
}
162+
```
163+
164+
We can also move the `>> 32` operation one step earlier in the computation graph and compute $\lfloor x / r \rfloor - \lfloor m / r \rfloor$ instead of $(x - m) / r$. This is correct because the lower 32 bits of $x$ and $m$ are equal anyway since
126165
127166
$$
128-
-n < (x - q \cdot n) / r < n
167+
m = x \cdot n^\prime \cdot n \equiv x \pmod r
129168
$$
130169
131-
Therefore, the final modulo operation can be implemented using a single bound check and addition.
170+
But why would we voluntarily choose to perform two right-shifts instead of just one? This is beneficial because for `((u64) q * n) >> 32` we need to do a 32-by-32 multiplication and take the upper 32 bits of the result (which the x86 `mul` instruction [already writes](../hpc/arithmetic/integer/#128-bit-integers) in a separate register, so it doesn't cost anything), and the other right-shift `x >> 32` is not on the critical path.
171+
172+
```c++
173+
u32 reduce(u64 x) {
174+
u32 q = u32(x) * nr;
175+
u32 m = ((u64) q * n) >> 32;
176+
return (x >> 32) + n - m;
177+
}
178+
```
132179

133-
Here is an equivalent C implementation for 64-bit integers:
180+
One of the main advantages of Montgomery multiplication over other modular reduction methods is that it doesn't require very large data types: it only needs a $r \times r$ multiplication that extracts the lower and upper halves of the result, which [has special support](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#ig_expand=7395,7392,7269,4868,7269,7269,1820,1835,6385,5051,4909,4918,5051,7269,6423,7410,150,2138,1829,1944,3009,1029,7077,519,5183,4462,4490,1944,5055,5012,5055&techs=AVX,AVX2&text=mul) on most hardware. This also makes it easily generalizable to [SIMD](../hpc/simd/) and larger data types:
134181

135182
```c++
136-
typedef unsigned long long u64;
137183
typedef __uint128_t u128;
138184

139-
u64 reduce(u128 x) {
185+
u64 reduce(u128 x) {
140186
u64 q = u64(x) * nr;
141187
u64 m = ((u128) q * n) >> 64;
142-
u64 xhi = (x >> 64);
143-
if (xhi >= m)
144-
return (xhi - m);
145-
else
146-
return (xhi - m) + n;
188+
return (x >> 64) + n - m;
147189
}
148190
```
149191
150-
We also need to implement calculating calculating the inverse of $n$ (`nr`) and transformation of numbers in and our of Montgomery space. Before providing complete implementation, let's discuss how to do that smarter, although they are just done once.
192+
Note that a 128-by-64 modulo is not possible with general integer division tricks: the compiler [falls back](https://godbolt.org/z/fbEE4v4qr) to calling a slow [long arithmetic library function](https://github.com/llvm-mirror/compiler-rt/blob/69445f095c22aac2388f939bedebf224a6efcdaf/lib/builtins/udivmodti4.c#L22) to support it.
193+
194+
### Faster Inverse and Transform
151195
152-
To transfer a number back from the Montgomery space we can just use Montgomery reduction.
196+
Montgomery multiplication itself is fast, but it requires some precomputation:
153197
154-
### Fast inverse
198+
- inverting $n$ modulo $r$ to compute $n^\prime$,
199+
- transforming a number *to* the Montgomery space,
200+
- transforming a number *from* the Montgomery space.
155201
156-
For computing the inverse $n' = n^{-1} \bmod r$ more efficiently, we can use the following trick inspired from the Newton's method:
202+
The last operation is already efficiently performed with the `reduce` procedure we just implemented, but the first two can be slightly optimized.
203+
204+
**Computing the inverse** $n^\prime = n^{-1} \bmod r$ can be done faster than with the extended Euclidean algorithm by taking advantage of the fact that $r$ is a power of two and using the following identity:
157205
158206
$$
159207
a \cdot x \equiv 1 \bmod 2^k
@@ -163,7 +211,7 @@ a \cdot x \cdot (2 - a \cdot x)
163211
1 \bmod 2^{2k}
164212
$$
165213
166-
This can be proven this way:
214+
Proof:
167215
168216
$$
169217
\begin{aligned}
@@ -176,41 +224,36 @@ a \cdot x \cdot (2 - a \cdot x)
176224
\end{aligned}
177225
$$
178226
179-
This means we can start with $x = 1$ as the inverse of $a$ modulo $2^1$, apply the trick a few times and in each iteration we double the number of correct bits of $x$.
180-
181-
### Fast transformation
227+
We can start with $x = 1$ as the inverse of $a$ modulo $2^1$ and apply this identity exactly $\log_2 \log_2 r$ times (5 times for $r = 2^{32}$), each time doubling the number of correct bits in the inverse, which is somewhat reminiscent of [Newton's method](../hpc/arithmetic/newton/).
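
The doubling step can be checked directly in code. Below is a minimal sketch for $r = 2^{32}$ that relies on unsigned overflow to do the arithmetic modulo $2^{32}$; the function name `inverse_pow2` is hypothetical and chosen only for this illustration.

```c++
#include <cstdint>

typedef uint32_t u32;

// Inverse of an odd number a modulo 2^32 (hypothetical helper name)
constexpr u32 inverse_pow2(u32 a) {
    u32 x = 1;                  // a * 1 == 1 (mod 2) for any odd a
    for (int i = 0; i < 5; i++) // correct bits: 1 -> 2 -> 4 -> 8 -> 16 -> 32
        x *= 2 - a * x;         // unsigned arithmetic wraps around modulo 2^32
    return x;
}
```

For example, `inverse_pow2(3)` should evaluate to `2863311531` (`0xAAAAAAAB`), since $3 \cdot 2863311531 = 2^{33} + 1$.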
182228
183-
Although we can just multiply a number by $r$ and compute one modulo the usual way, there is a faster way that makes use of the following relation:
229+
**Transforming** a number into the Montgomery space can be done by multiplying it by $r$ and computing modulo [the usual way](../hpc/arithmetic/division/), but we can also take advantage of this relation:
184230
185231
$$
186232
\bar{x} = x \cdot r \bmod n = x * r^2
187233
$$
188234
189-
Transforming a number into the space is just a multiplication inside the space of the number with $r^2$. Therefore we can precompute $r^2 \bmod n$ and just perform a multiplication and reduction instead.
235+
Transforming a number into the space is just a Montgomery multiplication by $r^2$. Therefore, we can precompute $r^2 \bmod n$ and perform a multiplication and reduction instead. This may or may not actually be faster, because multiplying a number by $r=2^{k}$ can be implemented with a left-shift, while multiplication by $r^2 \bmod n$ cannot.
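
To make the comparison concrete, here is a hedged sketch that implements both variants side by side for $r = 2^{32}$ and an odd 32-bit $n$. The struct name `MontgomerySketch`, the `r2` field, and the `transform_slow` / `transform_fast` method names are assumptions made for this example only; `reduce` is the fully normalizing variant, so both transforms return a value in the $[0, n)$ range.

```c++
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;

struct MontgomerySketch {
    u32 n, nr, r2;

    constexpr MontgomerySketch(u32 n) : n(n), nr(1), r2(0) {
        for (int i = 0; i < 5; i++)
            nr *= 2 - n * nr;         // nr = n^{-1} mod 2^32 (n must be odd)
        u64 r = (1ull << 32) % n;     // r mod n
        r2 = u32(r * r % n);          // r^2 mod n, precomputed once
    }

    // x * r^{-1} mod n, assuming x < r * n
    constexpr u32 reduce(u64 x) const {
        u32 q = u32(x) * nr;
        u64 m = (u64) q * n;
        u32 y = (x - m) >> 32;
        return x < m ? y + n : y;
    }

    // x * r mod n with a shift and a general modulo
    constexpr u32 transform_slow(u32 x) const {
        return (u64(x) << 32) % n;
    }

    // x * r^2 * r^{-1} = x * r mod n with one multiplication and one reduction;
    // x * r2 < r * n holds for any 32-bit x because r2 < n
    constexpr u32 transform_fast(u32 x) const {
        return reduce((u64) x * r2);
    }
};
```

Both methods should agree on every input, and `reduce(transform_fast(x))` should give back `x` for any `x < n`.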
190236
191237
### Complete Implementation
192238
193-
```c++
194-
typedef __uint32_t u32;
195-
typedef __uint64_t u64;
239+
It is convenient to wrap everything into a single `constexpr` structure:
196240
197-
struct montgomery {
241+
```c++
242+
struct Montgomery {
198243
u32 n, nr;
199244
200-
constexpr montgomery(u32 n) : n(n), nr(1) {
201-
for (int i = 0; i < 6; i++)
245+
constexpr Montgomery(u32 n) : n(n), nr(1) {
246+
// log2(32) = 5 iterations: 1 -> 2 -> 4 -> 8 -> 16 -> 32 correct bits
247+
for (int i = 0; i < 5; i++)
202248
nr *= 2 - n * nr;
203249
}
204250
205251
u32 reduce(u64 x) const {
206252
u32 q = u32(x) * nr;
207253
u32 m = ((u64) q * n) >> 32;
208-
u32 xhi = (x >> 32);
209-
return xhi + n - m;
210-
211-
// if you need
212-
// u32 t = xhi - m;
213-
// return xhi >= m ? t : t + n;
254+
return (x >> 32) + n - m;
255+
// returns a number in the [0, 2 * n - 2] range
256+
// (add a "x < n ? x : x - n" type of check if you need a proper modulo)
214257
}
215258
216259
u32 multiply(u32 x, u32 y) const {
@@ -219,44 +262,15 @@ struct montgomery {
219262
220263
u32 transform(u32 x) const {
221264
return (u64(x) << 32) % n;
265+
// can also be implemented as multiply(x, r^2 mod n)
222266
}
223267
};
224268
```
225269

226-
```c++
227-
montgomery m(n);
228-
229-
a = m.transform(a);
230-
b = m.transform(b);
231-
c = m.multiply(a, b);
232-
c = m.reduce(c);
233-
```
234-
235-
```c++
236-
int inverse(int _a) {
237-
u32 a = space.transform(_a);
238-
u32 r = space.transform(1);
239-
240-
int n = M - 2;
241-
while (n) {
242-
if (n & 1)
243-
r = space.multiply(r, a);
244-
a = space.multiply(a, a);
245-
n >>= 1;
246-
}
247-
248-
return space.reduce(r);
249-
}
250-
```
251-
252-
SIMD
253-
254-
166.79 ns
255-
256-
207.04 ns
270+
To test its performance, we can plug Montgomery multiplication into the [binary exponentiation](../hpc/number-theory/exponentiation/):
257271

258272
```c++
259-
constexpr montgomery space(M);
273+
constexpr Montgomery space(M);
260274

261275
int inverse(int _a) {
262276
u64 a = space.transform(_a);
@@ -273,4 +287,6 @@ int inverse(int _a) {
273287
}
274288
```
275289
290+
While vanilla binary exponentiation with a compiler-generated fast modulo trick requires ~170ns per `inverse` call, this implementation takes ~166ns, going down to ~158ns if we omit `transform` and `reduce` (a reasonable use case in modular arithmetic is for `inverse` to be used as a subprocedure in a bigger computation). This is a small improvement, but Montgomery multiplication becomes much more advantageous for SIMD applications and larger data types.
291+
276292
**Exercise.** Implement efficient *modular* [matrix multiplication](/hpc/algorithms/matmul).
