競プロ備忘録

競プロerの備忘録

Montgomery剰余乗算の学習メモ

Montgomery剰余乗算のアルゴリズムを用いると、乗算剰余を計算する際に純粋な除算・剰余算を省くことができる。

用途は色々あって、

  • 128bit剰余算は非常に遅いので、そのような演算が必要な場合には高速化が期待できる
  • SSEやAVX/AVX2には整数除算命令・剰余命令がないが、Montgomery剰余乗算によって一部の演算をカバーできる

などがあげられる。

既存の優れた記事がたくさんあるが、自分には理解が難しい部分があったので、メモ。

以下、法を Nとする。

Montgomery表現

一般に通じる用語かどうかは定かでないですが、この後使うので、もし一般用語じゃなければこの記事の中で使う用語の定義だと思ってください。

新たな定数として、 N \lt Rであり、 gcd(N, R) = 1であるような Rを置きます。これを用いて、ある数 a Rを乗じた数 aR \pmod Nを、 aのMontgomery表現と呼びます。

Montgomery表現での加算・減算は、法 Nにおける通常の加算・減算と同様に行えます。
 a + b \equiv c \pmod Nであれば、 aR + bR \equiv (a + b)R \equiv cR \pmod Nとなり、元の表現における演算結果をMontgomery表現に変換した結果となります。減算についても同様です。

乗算は少し違って、 a * b \equiv c \pmod Nであるとき、 aR * bR \equiv abR ^ {2} \equiv cR ^ {2} \pmod Nであって、元の表現での結果を単純にMontgomery表現に写したものにはなっていません。
なんとかしてこれに R ^ {-1}を乗じたいですが、そこで用いるのがMontgomery Reductionというアルゴリズムです。

Montgomery Reduction

Montgomery Reductionは、ある数 T (0 \le T \lt NR)に対して、 TR ^ {-1} \pmod Nを求めるアルゴリズムです。つまり、以下のように定義されます。

 MR(T) = TR ^ {-1} \pmod N

そして、Montgomery Reductionの具体的な手続きは、以下のようになります。
ただし、 NN' \equiv -1 \pmod Rとします。 gcd(N, R) = 1であり、 Rを法とする Nの乗法逆元は存在するので、このような N'も必ず存在します。

MR(T):
    t = (T + (TN' mod R) * N) / R
    if t >= N: then
        return t - N
    else: then
        return t

疑似コード中の tの範囲は、 0 \le t \lt 2Nとなります。

 N Rがどんな数でもよい…とすると、上述のアルゴリズムを用いても除算・剰余算は必要です。
しかし、 R 2 ^ {k} ( kは正の整数)であるとき、 Rによる除算は k bitの右シフト、剰余算は下位 k bit分のビットマスクになるので、非常に低コストになります。この制約を課すと、自動的に Nは奇数である必要が出てきます。( gcd(N, R) = 1より)

以下、 R = 2 ^ {32}とし、値の型は符号なし32bit整数を想定します。この場合、下位 32 bit分のビットマスクはラップアラウンドに置き換わり、さらに低コストです。
Rustではu32::wrapping_mulという関数があるので、以下のように書けます。

const N: u32 = {/* 法Nの値 */};
const N_PRIME: u32 = { /* NN' = -1となるN'を求める */};
fn mr(t: u32) -> u32 {
    let t = ((t as u64 + t.wrapping_mul(N_PRIME) as u64 * N as u64) >> 32) as u32;
    if t >= N { t - N } else { t }
}

Montgomery Reductionの手続きの正当性

文字の使い方は、上記疑似コード中のものに準じます。

まず、 tの右辺についてみると、

 T + (TN' mod R) * N \equiv T + TNN' \equiv T - T \equiv 0 \pmod R

であり、分子は Rの倍数ですから、 tは必ず整数です。
次に、 tの式の両辺に Rを掛けると、

 tR \equiv T + (TN' mod R) * N \equiv T + 0 \equiv T \pmod N

であり、 gcd(N, R) = 1より、法を Nとする Rの乗法逆元 R ^ {-1}が存在して、

 t \equiv TR ^ {-1} \pmod N

であることから、 MR(T)の手続きが TR ^ {-1}を結果として返すことが示せます。

最後に出力の範囲を考えると、 0 \le T \lt NRなので、

 0 \le TN' \mod R \lt R
 0 \le (TN' \mod R) * N \lt NR
 0 \le T + (TN' \mod R) * N \lt 2NR

よって、 0 \le t = (T + (TN' \mod R) * N) / R \lt 2Nが示せました。

Montgomery剰余乗算

 a * b \equiv c \pmod Nとしたとき、 a, bのMontgomery表現 aR, bRを単純に乗じると、

 aR * bR \equiv a * b * R ^ {2} \equiv cR ^ {2}

となりましたが、この結果に対してMontgomery Reductionを適用することで、

 MR(cR ^ {2}) \equiv cR ^ {2}R ^ {-1} \equiv cR

となるので、元の表現 a, bの乗算結果を正しくMontgomery表現に写すことができます。いわゆるMontgomery剰余乗算というアルゴリズムです。

Montgomery表現への変換

 aから aRへの変換もただの乗算なので、Montgomery剰余乗算によって求められます。
定数として R ^ {2}を用意しておき、 MR(aR ^ {2})を計算すればよいです。

各定数の求め方

 R \pmod Nは普通に剰余を求めればよいでしょう。

 R ^ {2} \pmod Nも同様普通に求めても良いですが、 R = 2 ^ {32}であることを用いると、 R ^ {2} \equiv R ^ {2} - N \pmod Nであり、 R ^ {2} - Nの64bit符号なし整数における表現は -Nなので、64bit符号なし整数における -N \mod Nを求めることでも代えられるようです。
頭良すぎて、最初全く意味わかりませんでした…

 N' \pmod Rですが、これは NN ^ {-1} \equiv 1 \pmod Rとすれば N' \equiv -N ^ {-1} \pmod Rですから、法 Rにおける Nの乗法逆元を求めればよいです。
 gcd(N, R) = 1よりこのような N ^ {-1}は必ず存在して、拡張ユークリッドの互除法や、オイラーの定理ニュートン法などを用いて求めることができます。

拡張ユークリッドの互除法

 R, Nに関する方程式 Rx + NN ^ {-1} = 1を解くことで、 N ^ {-1}の具体的な値が求まります。

オイラーの定理

 R Nは互いに素なので、オイラーの定理より、オイラーのトーシェント関数  \phi (N) := ( 1以上 N以下であって Nと互いに素な自然数の個数)を用いて、

 N ^ {\phi (R)} \equiv 1 \pmod R

となります。 R = 2 ^ {32}なので、 Rは偶数とは互いに素でなく、奇数とはすべて互いに素です。
したがって、 \phi (R) = 2 ^ {31}となり、 NN ^ {2 ^ {31} - 1} \equiv 1 \pmod Rなので、繰返し二乗法などで求められます。

実装例: Rust Playground

ニュートン法

ニュートン法を用いた有理数の乗法逆元の求め方をそのまま応用します。
すなわち、 f(x) = \frac{1}{x} - Nと置き、漸化式  x _ {n+1} = x _ {n} - \frac{f(x _ {n})}{f'(x _ {n})} ( f'(x) f(x)の1次導関数)に従って、乗法逆元を導出します。

漸化式に関数 f(x), f'(x)を代入して整理すると、 x _ {n+1} = x _ {n} (2 - Nx _ {n})となります。
 Nx _ {n} \equiv 1 \pmod pとすると、任意の正の整数 kを用いて、 Nx _ {n} = 1 + kpと表せます。これを漸化式に代入すると、

 Nx _ {n+1} = Nx _ {n} (2 - Nx _ {n}) = (1 + kp)(1-kp) = 1 - k ^ {2}p ^ {2} \equiv 1 \pmod {p ^ {2}}

となり、法 pにおける Nの乗法逆元 x _ {n}がわかれば、漸化式のステップを一つ進めることで、法 p ^ {2}における Nの乗法逆元が求まることがわかります。

ところで、 Nは奇数なので、 N = 1 + 2k (kは正の整数)などと表せ、

 N ^ 2 = (1 + 2k) ^ 2 = 1 + 4(k+k ^ {2}) \equiv 1 \pmod 4

より、 Nの法4における乗法逆元は N自身であることがわかり、これを上述の漸化式の初期値 x _ {0}とすることで、

 Nx _ {0} \equiv 1 \pmod {2 ^ {2}}
 Nx _ {1} \equiv 1 \pmod {2 ^ {4}}
 Nx _ {2} \equiv 1 \pmod {2 ^ {8}}
 Nx _ {3} \equiv 1 \pmod {2 ^ {16}}
 Nx _ {4} \equiv 1 \pmod {2 ^ {32}}

となり、たった4ステップで法 Rにおける Nの乗法逆元が求まります。
考えた人は天才ですね…

※もっとちゃんと考えると、 k + k ^ {2} = k(k+1)は必ず偶数ですから、もう1つ 2をくくり出しても良いことがわかり、実は4ステップで法 2 ^ {48}における乗法逆元まで求まっていることがわかります。

異なる手続きのMontgomery Reduction

以下のMizarさんの記事で述べられていた方法です。(勝手に人の記事のリンクなんて張ってよいのか?アウトならご指摘ください…)

64bit数の素数判定

詳しくはMizarさんの記事を読んでいただきたいですが、簡単に述べると、最初に紹介した手続きにおける N'の代わりに、純粋に法 Rにおける Nの乗法逆元を用い、以下のような手続きに従って結果を返します。

MR(T)
    t = floor(T/R) - floor(((TN^(-1) mod R) * N) / R)
    if t < 0: then
        return t + N
    else: then
        return t

 tの範囲は、 -N \lt t \lt Nとなります。

手続きの正当性

詳細はMizarさんの記事を見てください。正当性を示す流れは最初に示したアルゴリズムと同じです。

個人的にわからなかったのは、なんで床関数を T (TN ^ {-1} \mod R) * Nに別々に適用しても良いのか?ということです。ちゃんと示します。

まず、 T - (TN ^ {-1} \mod R) * N \equiv 0 \pmod Rより、 T \equiv (TN ^ {-1} \mod R) * N \pmod Rです。よって、

 \begin{cases}
{T = R \lfloor \frac {T}{R} \rfloor + r} \\
{(TN ^ {-1} \mod R) * N = R \lfloor \frac {(TN ^ {-1} \mod R) * N}{R} \rfloor + r}
\end{cases} (ただし、 r = T \mod R)

という2本の式ができ、これを変形して、

 \begin{cases}
{\lfloor \frac {T}{R} \rfloor = \frac{T - r}{R}} \\
{\lfloor \frac {(TN ^ {-1} \mod R) * N}{R} \rfloor = \frac{(TN ^ {-1} \mod R) * N - r}{R}}
\end{cases}

よって、

 \lfloor \frac {T}{R} \rfloor - \lfloor \frac {(TN ^ {-1} \mod R) * N}{R} \rfloor = \frac{T - r}{R} - \frac{(TN ^ {-1} \mod R) * N - r}{R} = \frac{T - (TN ^ {-1} \mod R) * N}{R}

となり、床関数を適用しない場合と同じ結果が得られました。
直感的にも、実数の引き算で結果が整数ならば小数点以下は必ず等しそうなので、そう思えば納得できる気がしました。足し算の場合には繰り上がりで整数になることがありますから、先に床関数を適用すると答えが小さくなるかもしれません。

このアルゴリズムの利点

冒頭で定義した手続きでは、 tの範囲が 0 \le t \lt 2Nでした。もし法 Nが型のサイズギリギリの場合には、 tの計算過程でオーバーフローが起こる可能性があります。
特に64bit整数にMontmery Reductionを適用する場合には、 tの結果がオーバーフローによるものか、純粋に Nより小さいのか、条件分けが複雑になることがあります。
一方でこのアルゴリズムでは、範囲がN未満であることがわかっており、最後の減算のタイミングでオーバーフロー検出を行えば十分なので、条件わけが単純です。オーバーフロー検出には、例えばRustであればu32::overflowing_subなどを用いることができます。

上方向のオーバーフローにもu32::overflowing_addなどがありますが、AVX2などのベクトル命令には符号なし整数の加減算命令がないことや、オーバーフロー検出用の命令がないことから、問題がさらに面倒なことになります。
このアルゴリズムの場合、AVX2には符号なし整数用のmax/min演算が用意されているため、減算の前に大小比較を行うことで問題は単純になります。

正規化の遅延

冒頭のアルゴリズムの話に戻ります。

 MR(T)の入出力範囲に注目すると、入力範囲が出力範囲を包含していることがわかります。
出力される値も \mod Nでは等しいですから、最後に Nを引かなくても、次のMontgomery Reductionで正しい答えが得られることが保証できます。

正規化の遅延を行う場合、加算・減算では、 Nではなく 2Nを境界値として正規化を行う必要があります。

2倍まで値が膨れることを許容できないといけないので、32bit整数では N \le 2 ^ {31}である必要があります。

2つ目のアルゴリズムの場合、出力範囲は異なりますが、入力範囲は同じなため、正規化のタイミングで問答無用で +Nしてやれば、同じ結果が得られます。
ただ、せっかく値域が Nで収まるという利点があるのに、保持しなければいけない値の範囲は1つめのアルゴリズムと同様 0 \le T \lt 2Nまで広がりますから、メリットがあるかというと微妙かもしれません。

Montgomery Reductionのベクトル化

 Nが32bit整数の範囲で収まるなら、演算は全体として64bitで収まりますし、 Nが定数である場合には最適化も効くでしょうから、わざわざMontgomery剰余乗算を実装するのは微妙かもしれません。
しかし、NTT(Number Theoretic Transform)など、乗算剰余を得る必要があるアルゴリズムをSSE, AVX/AVX2などのベクトル命令で実装する場合、Montgomery剰余乗算が目的を達成する一手段になります。

以下、Rustのstd::arch::x86_64に用意されたAVX2までの範囲の命令を用いて、__m256i型で保持した8個の要素にまとめてMontgomery Reductionを適用する例です。

定数の用意

8個の32bit整数をベクトルレジスタに詰める_mm256_set1_epi32という命令があります。求めた定数をこれで8個分詰めればよいです。

なお、RustはC/C++と違って、intrinsicをconst文脈で使えません。constでベクトルレジスタを保持したい場合、Rust1.56.0以降ではstd::mem::transmuteを用いて、配列から生成することができます。

const N: u32 = {/* 法N */};
const Nx8: __m256i = unsafe { std::mem::transmute([N; 8]) };

AtCoderでもそろそろ言語アップデートが実施されるという話がありますが、今のところRustのバージョンは大昔の1.42.0です。
この場合、共用体を用いたトリックめいたコードで生成が可能です。

union ConstSimd {
    arr: [u32; 8],
    reg: __m256i
}
const N: u32 = {/* 法N */};
const Nx8: __m256i = unsafe { ConstSimd { arr: [N; 8] }.reg };

constとはいっても、計算時のレジスタへのloadにはそれなりのコストがあり、レジスタの本数も16本しかありませんから、あまり増やし過ぎると逆効果な気もしますが、詳しいことはわかりません。詳しい方、教えてください…

 t = (T + (TN' \mod R) * N) / Rのパート

 TN' \mod Rについては、 T, N'をかけて下位32bit分を取得できれば良いです。これは32bit整数8個分をそれぞれ64bitの範囲に拡張した乗算結果の下位32bitを得る_mm256_mullo_epi32という命令1つで得られます。

その次の * Nについては64bit全体の結果を得る必要があります。これには_mm256_mul_epu32が使え、レジスタの0, 2, 4, 6番目の要素を符号なし64bit拡張した値の乗算結果を得ることができます。
1, 3, 5, 7番目については、引数として得た Tに_mm256_shuffle_epi32を用いて1, 3, 5, 7番目の値を0, 2, 4, 6番目に移動し、同じく_mm256_mul_epu32で結果を得ます。

 Tの加算も64bitの範囲で行う必要があるので、_mm256_add_epi64を使います。 Nの乗算の時点で8個の値は2つのレジスタにそれぞれ64bit整数4つ分として格納されているので、そのまま足せばよいです。
 Tは64bit整数4つ分として用意する必要がありますが、先ほど用いたshuffleと、後述するblendと、すべてのビットを0クリアしたレジスタを用意する_mm256_setzero_si256という命令を用いて対応します。

最後は32bit右シフトとレジスタの合成を行います。これには先ほど用いたshuffleと、8bitの定数に従って32bit整数8個を2つのレジスタから選択して合成する_mm256_blend_epi32という命令を使えます。

まとめると、以下のようなコードになります。

#[target_feature(enable = "avx2")]
unsafe fn montgomery_reduction_u32x8_without_normalization(t: __m256i) -> __m256i {
    let t_nprime = _mm256_mullo_epi32(t, N_RRIMEx8);
    let t_nprime_n_lo = _mm256_mul_epu32(t_nprime, Nx8);
    let t_nprime_n_hi = _mm256_mul_epu32(_mm256_shuffle_epi32(t_nprime, 0b10_11_00_01), Nx8);
    let res_lo = _mm256_add_epi64(t_nprime_n_lo, _mm256_blend_epi32(t, _mm256_setzero_si256(), 0b10101010));
    let res_hi = _mm256_add_epi64(t_nprime_n_hi, _mm256_blend_epi32(_mm256_shuffle_epi32(t, 0b10_11_00_01), _mm256_setzero_si256(), 0b10101010);
    _mm256_blend_epi32(_mm256_shuffle_epi32(res_lo, 0b10_11_00_01), res_hi)
}

順序に沿って実装しましたが、レジスタをカツカツで使っていると無駄なload/storeが挟まれて実行速度が急速に低下するかもしれません。
その場合はレジスタの使い方が最適になるように並べ替えたほうが良いかもしれません。(最適化でそのくらいよしなにやってくれる気がしますが…)

正規化パート

AVX命令に条件分岐はなく、代わりに条件を満たす部分にマスクをかけたレジスタを渡してくれる命令を用います。
今回は t N以上の箇所にのみ Nを引くという処理をしたいですが、mm256_cmpge_epi32/mm256_cmplt_epi32という命令はないので、他の命令を組み合わせて実装する必要があります。私の考えたやり方は2パターンあります。

1つ目は_mm256_cmpgt_epi32と_mm256_cmpeq_epi32の出力結果を_mm256_or_si256でまぜる方法、2つ目は_mm256_max_epu32で t Nの大きいほうを出力し、それを_mm256_cmpeq_epi32で tと比較する方法です。

AVX2までには符号なし整数比較命令がないため1つ目では符号付き整数比較命令を使っていますが、 t \ge 2 ^ {31}でバグります。
max/min演算については、_mm256_max_epu32, _mm256_min_epu32という命令があるため、2つ目の方法なら tが完全に32bitに収まっている限りは正確な結果を出力することができます。

以下は2つ目の方法での正規化です。

#[target_feature(enable = "avx2")]
unsafe fn normalization(t: __m256i) -> __m256i {
    let mask = _mm256_cmpeq(t, _mm256_max_epu32(t, Nx8));
    _mm256_sub_epi32(t, _mm256_and_si256(Nx8, mask))
}

 N \ge 2 ^ {31}の場合や、正規化を遅延している場合は加減算でオーバーフローすることがあるので、先に引数の大小を比較してmaskを取得するなどが必要かもしれません。

Montgomery剰余乗算のベクトル化

Montgomery Reductionと同じようにベクトル化が可能ですが、最初に64bit乗算が必要になるため、レジスタを分割するタイミングが変わります。

参考にした記事

Mizarさんの記事

  • Montgomery剰余乗算のアルゴリズムはここで初めて知った
  • URLは記事中にあるので割愛

えびちゃんさんの記事