競プロ備忘録

競プロerの備忘録

技術室奥プログラミングコンテスト#6 Day1D - ABS SUM

コンテスト名が長すぎる…

解法

取り得る区間の累積和の絶対値の総和を求めよと言われている。

まず配列 Aについて区間 [l, r)の累積和の求め方を考えると、 rまでの累積和から、 l- 1までの累積和を引けば求まる。
つまり、 1から iまでの累積和を SA _ {i}と置けば、

 \sum _ {i=l} ^ {r} {A _ {i}} = SA _ {r} - SA _ {l - 1}

となる。

ところで、絶対値を求めよと言われているので、以下のように場合分けする。

 \displaystyle
|\sum _ {i=l} ^ {r} {A _ {i}}| = \left\{
\begin{array}{ll}
SA _ {r} - SA _ {l - 1}   &   (SA _ {r} \ge SA _ {l - 1})  \\
SA _ {l - 1} - SA _ {r}        &   else
\end{array}
\right.


この式をもとに SA _ {i}本位で考えると、

  •  i未満のインデックス向けには、自身より大きいものの数だけ SA _ {i}を引き、自身以下のものの数だけ足す
  •  i以上のインデックス向けには、自身より大きいものの数だけ SA _ {i}を引き、自身以下のものの数だけ足す

となることがわかり、結果的には、累積和配列全体をソートし、自身以上のインデックスの分引く、自身未満のインデックスの分足す、とするだけで答えが出ることがわかる。

実装例は以下の通り。
Submission #49840787 - 技術室奥プログラミングコンテスト#6 Day1

あとがき

最初解いたときには、「結果的には...」のところが気づけなくて、座圧+Fenwick木で自身の前後の自分より大きい/小さい要素の数をクソ真面目に数えていました。
記事書きながら思考まとめる最中で、あれこれ条件もっとシンプルにできるんじゃないか、と気づきました。備忘録書くなかでも発見はありますね。

ABC338D - Island Tour

本番で解けなくてひどい目に…見た目で無理そうと判断して飛ばしたのがチンパンプレーだったという話で、あとで考えたらそう難しくなかったのが余計に痛すぎる。

Twitter見てたら遅延セグ木とかいもす法とか言ってる人が多いですが、差分更新したほうが簡単じゃないかということでその解法を。

解法

最初から N 1の間がつながっていないとする。この場合は、 X _ {i} X _ {i+1}の差の総和がそのまま答えになる。

これを初期値として、橋を切り離す位置を後ろにずらしながら差分更新していくことを考える。橋を切る位置が後ろに一つずれるというのは、 1番目の島が切り離されて N+1番目の島として後ろに接続されることと同じである。

なので、 X _ {i} 1であるような iについて、 X _ {i-1}, X _ {i+1} 1の距離を、 N + 1との距離に置き換えればよい。置き換えた後は、 X _ {i} N + 1に置き換え、次の島を見ていく。
以後、頭の島は 2, 3, 4...と変化していくので、 Nまでこれを繰り返すだけで良い。

最終的な答えが出るまでに Xに登場する島を最初から最後までなめる必要があるので O(N)回ループが回り、各 X _ {i}は1回ずつチェックされるので、全体で O(N + M)で答えが出る。

実装例は以下。
Submission #49748310 - AtCoder Beginner Contest 338

アセンブリ言語でABS (AtCoder Beginners Selection) を解く

AtCoderでは2023年の言語アップデートでAMD64向けのアセンブリ言語(の中でもNASM)による提出ができるようになりました。
これで問題解いてみたいな~とは思っていたもののなかなか手が出ていませんでしたが、ABS程度ならいけるかと思い立って解いてみました。

ちなみに、先駆者がいるか提出一覧を調べてみたところ1名はいたようです。なので、初制覇というわけではないです。

以下で解法を書きたいわけですが、まあ高級言語ならやるだけで済んじゃうので、普通にメモ書いてもつまらないです。
どうせなら解いててハマったところとか書いていこうかなと思います。

NASMの文法とかアセンブリ言語の約束事など

一応述べておいた方が良さそうなので、最低限知ってなきゃいけないことだけ書きます。私は専門家ではないので、正しさは保証しません。

NASMについて

アセンブリ言語と一口に言っても、アーキごとに命令セットが違うのは当然として、アーキが同じでもアセンブラごとに文法が異なります。
有名どころでは、GAS記法とIntel記法というのを聞いたことがあるかもしれません。

AtCoderで使えるのは色々あるアセンブラのうち、NASMというアセンブラになります。
NASMの文法が知りたい場合は、公式のマニュアルを見るのが確実です。

NASM - The Netwide Assembler

コードの構造について

アセンブリ言語の場合は、コードの構造がほぼそのまま実行ファイルの構造と結びつくので、実行ファイル中の命令やデータの配置を知っておく必要があります。
実行形式によって配置や構造は異なるのかもしれませんが、私はELF64しか知らないですし、AtCoderの環境も多分そうなのでその前提でいきます。

AtCoderでコードを実行するにあたっては最低でも以下の4つの領域があるということを知っておく必要があります。

  • テキスト領域
    • 命令が格納される領域。.textというセクション名で定義。
  • データ領域
    • 何らかのデータで初期化された領域。.dataというセクション名で定義。
  • BSS領域
    • 0初期化されたデータが格納された領域。.bssというセクション名で定義。
  • スタック領域
    • いわゆるローカル変数の確保やレジスタの保存を行う領域。コード中では特にセクション名を付けなくてよい。

これらはリンカスクリプトでオブジェクトファイル中のどこにどの程度のサイズで配置するかを制御できますが、GCCC言語のコードをコンパイルをするときにいちいちリンカスクリプトを作らなくていいように、いつもはリンカがよしなにやってくれます。
例えば、ldを使っている場合はld --verboseなどとすれば、デフォルトでどういうリンカスクリプトが適用されているか確認できます。

エントリポイントについて

ネットでコードの書き方について調べると、メインのコードの始まりのラベルがmainだったり_startだったりすると思います。
これはcrt0.oやcrt1.oのようなプログラムの起動処理を実行するルーチンがリンクされているかどうかで違います。(と、私は認識してます。違ったら教えてください)
AtCoderの環境ではリンクされますから、エントリポイントはそちら側にあり、私たちはmainから始める必要があります。

ABIについて

Application Binary Interfaceです。プログラム間のバイナリレベルでのインタフェースを定義しています。

Linux環境ではおそらくSystem V ABIというのが使われていると思います。

https://www.uclibc.org/docs/psABI-x86_64.pdf

私も細かいことは詳しく知りませんが、差し当たって知っていなければいけないのはレジスタの使い方でしょう。
上記のドキュメントで言えば、「Figure 3.4: Register Usage」がそれにあたります。

AMD64のプロセッサではGPRが16本ありますが、それぞれ引数渡し用、返り値用、一時用などの用途や、ルーチン呼び出し時のレジスタ保存の役割分担などが決まっています。
割とこれハマりがちだったので、頭に入れておいたほうが良いです。

命令セットについて

当たり前ですが、AMD64向けの命令しか使えません。

Intelからマニュアルが出ていますから、これを参照しましょう。具体的には、「Intel64 and IA-32 Architectures Software Developer's Manual Volume 2 (2A, 2B, 2C, & 2D) Instruction Set Reference, A-Z」です。

各問題の解法

ようやくですが、解法です。

ちなみに、最初はシステムコールとHW命令だけで突っ張るつもりだったのですが、ちょっと無理ゲーなので、入出力だけはライブラリ使います。
libcがリンクされているからだと思うのですが、extern printfprintfが使えます。

PracticeA - Welcome to AtCoder

提出例:Submission #49418142 - AtCoder Beginners Selection

初手なので少し細かめで。

scanfprintfを使うつもりなので、externでそれを宣言します。また、フォーマット用の文字列が必要ですが、これはデータ領域に格納します。
fmtpのうしろの10って何?って思われるかもしれませんが、Asciiコード表を見ると、10進数の10は改行コードであることがわかります。なので、C言語で言えば"%d %s\n"とおなじです。最後に0を付けているのは、NULL終端のためです。

入力のsは文字列なので、配列が必要です。これはどの領域にとってもいいのですが、解いたときはBSS領域に格納したようです。
データ領域のときはdbが8bitサイズ、BSS領域のときはresbが8bitサイズとなります。なので、s resb 1024は、C言語で言えばstatic uint8_t s[1024];とかと同じです。

.textセクションに入って、まずはrbpをスタックに保存します。
その後、スタックポインタを引いてスタック領域にローカル変数保存用の領域を確保します。32という数字に特に意味はないです。12で十分ですが、まあ多少多めにとっても問題はないです。

次はscanfのための引数をレジスタに渡します。
ABIによれば、引数をレジスタ経由で渡す場合、1つ目から順にrdi, rsi, rdx, rcx, r8, r9を使う必要があります。なので、fmtのアドレスをrdiに渡して、それ以外は確保したスタック領域から順に4Byte区切りでアドレスをレジスタに渡していきます。
fmtのアドレスを取得するときに[rel fmt]としなくてはならないのは、このコードがPIE(位置独立実行形式)としてリンクされるからのようです。この場合、rip相対で参照する必要があり、rip相対参照は[rel fmt]とすることでできます。

callの前にやっているのは、raxのクリアです。xor eax, eaxはよくあるレジスタのクリア方法で、eaxをクリアすればraxまでクリアされます。
なんでクリアするのかですが、ABIでraxは可変長引数を渡すときにベクトルレジスタをいくつ使うかの情報を渡すのに使われることになっているからです。試してみるとなんかクリアしなくても動くのですが、scanfprintfも可変長引数を要求する関数なので、一応クリアしておきます。

ついにscanfを呼びます。共有ライブラリのルーチンは、PLT (Procedure Linkage Table)を経由して呼び出されるため、scanfcall scanf wrt ..pltとしなくてはなりません。これはprintfも同様です。(NASMのリファレンスの「10.2.5 Calling Procedures Outside the Library」とかを参照するとよいです)

入力ができたら、まずはa + b + cを作ります。スタックの若いほうから4Byteおきに数値を格納したので、順にesiに格納して足し込んでいきます。
a + b + cが作れたら、rdxsのアドレスを格納します。

ここまででprintfの第2,3引数は格納できているので、あとは第1引数としてfmtpのアドレスを格納して、raxをクリアしたら、printfを呼び出します。

最後はスタックポインタとベースポインタをもとに戻して、retで戻ります。
スタックポインタに足す数字がめちゃくちゃでトラブるというのはありがちなので気をつけましょう。(だから32固定にしているというのもある)
また、ここでraxをクリアしないと、終了ステータスが0以外になってしまい、REで死にます。ちゃんとクリアしてからretしましょう。

ABC086A - Product

提出例:Submission #49519868 - AtCoder Beginners Selection

掛け算と偶奇判定と条件分岐が出来ればよいです。

掛け算にはmulが使えます。これは符号なし整数の掛け算を行う命令で、符号付き整数の場合にはimulという命令があります。
オペランドは1つだけ、レジスタかメモリをとれます。もう一方の項はどうしたって話ですが、Instruction Set Referenceを読めばわかる通り、64bitの計算ならraxが暗黙的に指定されます。
計算結果は、64bitの場合、下位ビットがrax、上位ビットがrdxに格納されます。なので、rdxが破壊されます。rdxに意味のある値が入っている場合には、事前に退避させる必要があります。

偶奇判定ですが、これは結果が入っているraxのLSBが1かどうかを確認すればよいでしょう。C言語でいえば、(rax & 1) == 1をチェックできればよいです。
これにはtestが使えます。これはレジスタ、メモリ、即値を2つ引数に取り(ただし、即値は1つまで)、ANDをとった結果に応じてフラグレジスタを変化させます。
test rax 1であれば、raxの最下位ビットが0である場合にZFが1にセットされ、そうでないとき0にセットされます。

条件分岐には、je, jne, jg, jl...(その他多数)を使うことができます。詳細はInstruction Set Reference参照ですが、今回使ったjeは、ZFが立っているときに指定したラベルへ飛びます。
なので、test rax 1のあとje evとすることで、raxと1のANDが0の場合はev("Even"を出力するほう)へ飛び、そうでないときはそのまま"Odd"を出力できます。
"Odd"を出力した後は、"Even"を出力する処理の後ろに設置したepilog:ラベルに無条件ジャンプします。

結果として偶奇判定と条件分岐を実現することができます。

ABC081A - Placing Marbles

提出例:Submission #49520205 - AtCoder Beginners Selection

3Byte入力して、すべて足し込んでから144 (Asciiコードで'0'は10進数の48)を引けばよいです。

silrsiの最下位1Byteです。メモリから1ByteとってくるときはBYTEを指定すればよいです(小文字でもOK)。2ByteならWORD、4ByteならDWORD、8ByteならQWORDです。

ABC081B - Shift only

提出例:Submission #49520373 - AtCoder Beginners Selection

Trailing Zerosを数えて、Minをとって出力すればよいです。加えて、地味に面倒くさいループ処理です。

Trailing Zerosを数えるのには、tzcntという命令があります。2つのオペランドが必要で、1つ目はレジスタ、2つ目はレジスタかメモリです。
2つ目のオペランドのTrailing Zerosを数えて、結果は1つ目のオペランドに指定したレジスタに格納されます。
計算対象が0だと未定義動作になりますが、制約ではそのような入力はないので問題ないです。

Minを取るための命令はたぶんないと思います。なので、頑張って条件分岐で求めます。 提出例ではcmpという命令を使っています。testとの違いは、cmpは1つ目のオペランドから2つ目のオペランドを引き算した結果でフラグレジスタを設定するところです。
jgcmpの結果、1つ目のオペランドが大きかった場合に指定されたラベルにジャンプします。(厳密には、OFとSFが等しいかつZFが立っていないときにジャンプするが、subの結果がこうなる場合は1オペランド目が大きい)

ループのパターンは決まっていて、まずカウンタ用のレジスタのクリア、その直後にラベルを付け、ループ末尾でincaddを使ってカウンタを更新、最後に継続条件を判定して、満たす場合はループの頭のラベルにジャンプします。
今回はC言語でいうところのfor (int i = 0; i < n; i++)がしたいので、これをcmpjgを使って実装します。

ABC087B - Coins

提出例:Submission #49520715 - AtCoder Beginners Selection

ただただループ処理が面倒くさいだけなので、頑張って実装します。

内側ループでは割り算が必要ですが、整数除算にはdivが使えます。
これもmul同様に1つだけオペランドを取りますが、面倒なことに即値は指定できないので、レジスタに50を格納してから計算します。
暗黙に指定されるオペランドrax, rdxの2つで、rdxが上位、raxが下位です。結果は、商がrax、剰余がrdxに格納されます。
実はハマりどころですが、連続でdivするときにはrdxをクリアしないと、前の計算結果の剰余がそのまま次の計算の暗黙のオペランドの上位ビットに含まれてしまい、計算結果が大きく狂います。
今回はXが50の倍数であることによって、rdxが必ず0になるためクリアは不要ですが、この後の問題でこれによってドハマりしました。気をつけましょう。

ABC083B - Some Sums

提出例:Submission #49386518 - AtCoder Beginners Selection

桁和を求めるなら普通ループ書くと思いますが、アセンブリ言語でループ書くのは面倒ですし、高々5ケタしかないので、ループ展開しちゃいます。

divを使えば商と剰余が手に入るので、10で割りながら剰余を足し込みます。
先ほど述べたように、rdxをちゃんとクリアしないと、結果がめちゃくちゃになります。

ABC088B - Card Game for Two

提出例:Submission #49398501 - AtCoder Beginners Selection

ソートして交互に足し引きするだけの問題ですが、アセンブリ言語で書くのはあまりに苦痛な問題です。
libcが使えるならたぶんqsortが使えると思いますが、ポリシー的にびみょいので、ここは気合で実装します。

Nの制約が小さいので、バブルソートや選択ソートも余裕で通りそうです。ここでは実装が簡単そうな選択ソートをすることにしました。
実装は気合で二重ループと条件分岐を書くだけです。

ちょっとした工夫としては、選択された値をメモリに書き戻す必要はないので、レジスタに値を格納した後はそれをスワップし続けます。
スワップ用の命令としてはxchgというものがあり、レジスタとメモリの値のスワップができます。暗にlockが指定されるためパフォーマンスが悪いらしいですが、まあ今回の用途では誤差です。

もう一つ面倒なのが、配列のアクセスの際、ループカウンタのレジスタを使って、例えば[rel s + rax]のように参照する箇所を指定することができない点です(即値なら可能)。
どうすればいいかですが、これはいったんレジスタlea rbp [rel s]などとしてアドレスをロードし、このレジスタに値のサイズを加算しながら参照すればよいです。実は何かよいやり方があるのかなあと思っていますが、今のところ見つかっていないので、これでしのぎました。

ABC085B - Kagami Mochi

提出例:Submission #49398793 - AtCoder Beginners Selection

ソートが本質みたいな問題なので、前の問題のコードをそのまま使いまわします。

あとは隣接項の大小関係で場合分けして答えを加算するだけです。

ABC085C - Otoshidama

提出例:Submission #49398990 - AtCoder Beginners Selection

気合で二重ループ書くだけの問題です。ここまで述べたハマりどころ意外には特にハマりどころがありません。

ABC049C - 白昼夢

提出例:Submission #49417731 - AtCoder Beginners Selection

実質ボス問です。

高級言語なら配列作ってDPやるだけだろ?って話ですが、文字列操作のライブラリは使いませんから、あらゆる操作が苦痛な問題です。strlenstrcmpも使えないので、ループや何らかのハックで頑張るしかありません。

まずは文字列のサイズがわからないとどうしようもないので、頑張ってループを書いて確かめます。
NULL終端されているはずですから、文字の0判定を行い、終端に辿り着いたら抜けます。
0判定には、同じレジスタを2つのオペランドに指定してtestするという方法が使えます。

その次は本題のDPです。
カウンタを回しながら、まずそこまでをdream, dreamer, erase, eraserで構成できるのか確認したのち、そこからこの4つで文字列を伸ばせるか確認します。
伸ばせるかの確認が面倒なところですが、いずれの文字も8文字よりは小さいことを考えると、64bit整数の一致判定で何とかなります。
まずはレジスタにQWORDとしてメモリをロードし、文字列長でマスクをかけ、文字列が構成するビットパターンにマッチするか確認すればよいです。

andcmpは即値をとれるのですが、Instruction Set Referenceを見ればわかるとおり、64bitの即値は取れません。なので、movレジスタにいったん即値をロードし、これをオペランドにする必要があります。

ABC086C - Traveling

提出例:Submission #49521145 - AtCoder Beginners Selection

最終問題です。ここまで来てしまうと、あんまりハマりどころもないので、説明すべき点もないです。

入力を取りながら、前の座標とのマンハッタン距離を求めて時間を引き、マイナス判定と偶奇判定を行い、どちらかの判定に失敗した時点でbad:に飛べばよいです。

マイナス判定については、いずれも直前のsubでマイナスになったらSFが立ちますから、jsjnsで目的のラベルに飛ばせます。

あとがき

デバッグがとにかく辛かったです。なんせprintfするだけでも、引数の準備が微妙に面倒くさかったり、最後にしか出力しない前提で書いているコードだと自分で保存しないといけないレジスタを何の考えもなく使っていてprintfの内部でそいつが破壊されて結果が意味不明になったりということがあります。
手元で環境用意してデバッガでステップ実行するのが手早い説はありますが、そこまでする根性がなかったです。

ただ、ABIやら実行形式やらメモリ上のデータの配置やら、高級言語ばかり書いているとわからないがちなところを抑えないと書けないだけあって、とても勉強になりました。
今後もたまにはアセンブリ言語触っていきたいですね。

ABC333F - Bomb Game 2

本番全くわからなくて焦った…

解説読んでもあんまりピンと来なかったのですが、半日考え直してようやく理解できたので記事にします。

解法

無限に誰も排除されないことがあるけどどうしよう...となるので、とりあえずサンプル1がなんでこれでいいのかを解き明かさないといけない。

まず1周目で1人目が排除される確率は \frac{1}{2}、2人目が排除される確率が \frac{1}{2}\cdot\frac{1}{2}=\frac{1}{4}、2周目に回って、1人目排除が \frac{1}{2}\cdot\frac{1}{2}\cdot\frac{1}{2}=\frac{1}{8}...と続き、それぞれの総和がそれぞれの排除される確率であることがわかる。

もうちょっと考えると、1人目が排除される確率は \frac{1}{2}(1+\frac{1}{4}+\frac{1}{16}...)となって、2人目が排除される確率は \frac{1}{4}(1+\frac{1}{4}+\frac{1}{16}...)となることもわかる。

ここまでくると、 x = \frac{1}{4}として、 1 + x + x^{2} + x^{3} + x^{4} + ...みたいなものを計算したくなってくる。
 1 + x + x^{2} + .... + x^{n} = \frac{1-x^{n+1}}{1-x}であって x = \frac{1}{4}なので、雑に考えると nを無限に飛ばせば x^{n+1} 0になって消えるでしょう、みたいな気持ちになると、上の式は \frac{1}{1-x}と等しいだろうと予想できる。
実際色々ググってみたりするとこれが正しいことはわかる。

これをもとにサンプル1の話に戻ると、1人目が排除される確率は \frac{1}{2}\cdot\frac{4}{3}=\frac{2}{3}、2人目が排除される確率は \frac{1}{4}\cdot\frac{4}{3}=\frac{1}{3}となり、それぞれを1から引けば他方が最後まで残る確率がわかるので、出力が正しいということがわかる。

もう少し抽象化して考えて、 i人いる状態で、先頭から j人目( 0 \le j \lt i、以降単に jと呼ぶ)が排除される確率はいくらだろうということを考えてみる。

1周目で排除される確率は、 jより前が全員 \frac{1}{2}の確率で生き延び、 j \frac{1}{2}の確率で排除されるため、 \frac{1}{2^{j+1}}となる( jが0から始まることに注意)。
1周を誰も排除されず周り切る確率は明らかに \frac{1}{2^{i}}なので、2周目以降 jが排除される確率は、 \frac{1}{2^{i+j+1}}, \frac{1}{2^{2i+j+1}}, \frac{1}{2^{3i+j+1}}...となり、これらの総和が jが排除される確率になる。

これを式にすると、

 \sum^{\infty} _ {k=0} {\frac{1}{2^{ki+j+1}}}

となり、整理すると、

 \begin{eqnarray}
\sum ^ {\infty} _ {k=0} {\frac{1}{2^{ki+j+1}}} &=& \frac{1}{2^{j+1}} \sum ^ {\infty} _ {k=0} {\frac{1}{2^{ki}}} \\\\
&=& \frac{1}{2^{j+1}} \sum ^ {\infty} _ {k=0} {(\frac{1}{2^{i}}) ^ {k}} \\\\
&=& \frac{1}{2^{j+1}} \cdot \frac{1}{1 - \frac{1}{2^{i}}} \\\\
&=& \frac{1}{2^{j+1}} \cdot \frac{2^{i}}{2^{i} - 1} \\\\
&=& \frac{2^{i-j-1}}{2^{i} - 1}
\end{eqnarray}

となることがわかる。長いので、以降はこの式を p _ {i, j}と表すことにする。

(ここまではコンテスト中に分かっていた点で、以降はコンテスト後の考察)

これを使って、以下のようなDPを考える。

 dp _ {i, j} := i人残っている時点で j人目であるような人が最後まで生き残る確率

まず、1人しか残っていない時点で0番目であるような人が生き残る確率は 1なので、 dp _ {1, 0} = 1とわかる。

これを昇順に計算することを考えるが、遷移の考え方がちょっとこんがらがる。
たとえば、 i人いる状態で j人目だった人が、 i-1人いる状態で 0人目になるにはどういう遷移をするか考えると、これは j - 1人目が排除されたときなので、

 dp _ {i, j} += dp _ {i-1, 0} p _ {i, j-1}

と遷移できる。同じように 1人目、 2人目...と遷移することを考えると、遷移式は以下のようになる。

 \begin{eqnarray}
dp _ {i, j} = &&dp _ {i-1, 0}p _ {i, j-1} + dp _ {i-1, 1}p _ {i, j-2} + dp _ {i-1, 2}p _ {i, j-3} + ... \\\\
&+& dp _ {i-1, j-1}p _ {i, 0} + dp _ {i-1, j}p _ {i, i-1} + dp _ {i-1, j+1}p _ {i, i-2} + ... \\\\
&+& dp _ {i-1, i-2}p _ {i, j+1}
\end{eqnarray}

右辺のそれぞれの項の dpの2つ目の添え字が「 jが生き残った i-1人の中で何人目になるか」、 pが「 jがそこに移動するためにしかるべき人が排除される確率」であると考えると理解しやすい。

ここまでで O(N ^ 3)解となるが、これを高速化するには、累積和を用いることができる。

具体的には、 dp _ {i-1, *}の前後から、適切に pをかけて累積和をとる。先頭 j項は前からの累積和を単に足すだけで良い。以降の i - j項は後ろからの累積和を用いるが、 jの値は動き続けるので、少し工夫が必要。
 p _ {i, j} = \frac{2^{i-j-1}}{2^{i} - 1}を思い出すと、 jが増えるたびに分母の次数が1つずつ下がっていくことがわかる。したがって、 p _ {i, j}から p _ {i, j+1}を導くのは \frac{1}{2}をかければよく、 p _ {i, j + k} (\frac{1}{2}) ^ kをかけて導けるから、累積和に適切な回数 \frac{1}{2}をかければよい。

累積和の部分は言葉ではイマイチ上手く伝えられないので、以下の実装例から。

Submission #48642405 - Toyota Programming Contest 2023#8(AtCoder Beginner Contest 333)

あとがき

こういうゴリゴリ計算するのを要求される問題は破滅度が高いので大嫌いなんですが、こういうのを倒せないと黄色にはなれないんだなあと改めて思いました。(小並感)

なんでもできるRollingHashを作りたい

ABC331Fでロリハをセグ木にのせろと言ってるようなもんじゃんっていう問題が出たわけですが、その場で解けず悲惨な目にあいました。

なので1点更新のロリハをライブラリ化したわけですが、最近Link Cut Treeのライブラリ化を試みていたりと地味に平衡二分木がマイブームで、平衡二分木にロリハ乗せればよさげなものができそうだと気づきました。

というわけなのでそのライブラリ整備の備忘録です。ちなみに、理論的な部分は多く語れません。(例えば、Splay Treeの計算量とか、なんでそうなるかは全くわかってない)
実装記録だと思ってください。

Rolling Hash

Rolling Hashについてはだいたいわかっているものとしますが、この記事では文字列 s _ 0 s _ 1 s _ 2 ... s _ {N-1}とある基数 b、法 Mを用いて、以下の式で表されたハッシュ値を導くアルゴリズムということにします。

 (b ^ {N-1} s _ 0 + b ^ {N - 2}s _ 1 + b ^ {N - 3}s _ 2 + ... + bs _ {N - 2} + s _ {N-1}) \bmod M

 Mにはクソデカなメルセンヌ素数を使うと良いとか、基数はランダムなほうが良いとか、色々ありますが、この記事ではそんな重要な話ではないので割愛です。(「安全で爆速なRollingHashの話」とかググると良いでしょう)

文字列全体のハッシュ値は以上のように計算できるわけですが、部分文字列のハッシュ値を切り抜きたくなることもあるでしょう。
例えば s _ l s _ {l+1} s _ {l+2} ... s _ {r - 1}を切り抜くときは、以下のような値を導出したいわけです。

 b ^ {r - l} s _ l + b ^ {r - l - 1} s _ {l + 1} + b ^ {r - l - 2} s _ {l + 2} + ... + s _ {r - 1}

長さが r - lに変わったので、初項の基数の肩に乗る指数も r - lになっています。
これを計算する方法としては、メモとして s[0..i]のハッシュ値 i番目とする配列( memoとする)を持っておき、

 memo _ {r - 1} - memo _ {l - 1} b ^ {r - l + 1}

とすることで可能です。お気持ちとしては、 memo _ {l - 1} s _ {l - 1}の項の基数は 1で、上の式からもわかる通り memo _ {r - 1} s _ {l - 1}の項の基数は b ^ {r - l + 1}なので、基数を合わせて引き算すれば l - 1項目より前はゴッソリなくなりそうだなと思えそうです。

1点更新できるRolling Hash

先述のようなRolling Hashで1点更新を行うとすると、 memoの再構築が必要です。
例えば、 s _ 0 tに変えたいとすると、 s _ 0 memo _ 0では s _ 0として、 memo _ 1では bs _ 0として、 memo _ iでは b ^ {N - 1 - i} s _ iとして足されているので、それぞれを引いたうえで t,  bt,  b ^ {N - 1 - i} tを足しなおす、などとしなくてはなりません。このままでは無理筋なので何とかしようと考えると、真っ先に思い浮かぶものにセグメント木があります。

セグメント木に乗せる方法は色々あるかもしれないですが、私が思いついた方法は、区間がカバーする部分文字列のハッシュ値と最高次の基数を乗せる方法です。

最下段の i番目のノードでは、 s _ {i}1文字分のハッシュ値と、最高次の基数 1を持てば良いです。
内部のノードでは子ノードの値を用いて集約値を計算する必要がありますが、親と左右の子の(ハッシュ値、基数)のペアをそれぞれ (H _ p, B _ p), (H _ l, B _ l), (H _ r, B _ r)と表すと、

 (H _ p, B _ p) = (H _ {l}B _ {r}b + H _ {r}, B _ {l}B _ {r}b)

となります。左の子の最低次の基数は 1で、これを右の子の最高次の基数 B _ rより次数1個分大きくしないといけないので、左の子のハッシュ値 B _ {r}bをかけて右の子のハッシュ値と足せばよいです。

見てわかる通り演算は可換じゃないので、ちょっと具体的には思いつかないですが演算が可換なことを前提にした雑な実装をしていたりすると、破滅するかもしれないので気を付けたほうが良いでしょう。

これで1点更新可能なRolling Hashが出来たので、ABC331Fは撃破できます。

1点更新と反転ハッシュ取得可能なRolling Hash

これは簡単で、逆順のRolling Hashをもう1つ持てば良いです。ただ、反転は添え字でバグらせがちなので、できればデータ構造1つで殴りたいです。

これも簡単に解決可能で、セグ木のノード1つに正順序の計算結果と逆順序の計算結果の両方を持ち、逆順序の計算では先述の式の右辺の l, rを反転させればよいです。
ようするに、こんな感じです。(modの計算はめんどくさいの省略です)

trait Monoid {
    fn id() -> Self;
    fn op(l: Self, r: Self) -> Self;
}

#[derive(Clone, Copy)]
struct Node {
    f: (u64, u64),
    r: (u64, u64),
}
impl Monoid for Node {
    fn id() -> Node {
        Node {
            f: (0, 1),
            r: (0, 1),
        }
    }
    fn op(l: Node, r: Node) -> Node {
        Node {
            f: (l.f.0 * r.f.1 * b + r.f.0, l.f.1 * r.f.1 * b),
            r: (r.r.0 * l.r.1 * b + l.r.0, r.r.1 * l.r.1 * b),
        }
    }
}

正順序のハッシュが欲しいときはNode.fを、逆順序のハッシュが欲しいときはNode.rを使えばよいです。

とはいっても構造体のサイズがバカでかくてちょっと嫌ですよね…計算で逆順を計算できたりするのかなーなどと考えていたりしますが、今のところ解決法は見つかっていません。
まあパフォーマンス的に若干微妙かもしれませんが、致命的なほどではないはずです。

ちなみに、ここまでやると回文判定もおまけで可能になります。正順序のハッシュ値と逆順序のハッシュ値が一致しているか確かめればよいです(それはそう)。

1点更新挿入削除と任意の範囲の反転が可能なRolling Hash

これが本題です。 もうお察しかと思われますが、平衡二分木に乗せます。私はSplay Treeしか書けないので、それに乗せました。

本当に乗せるだけで終わりなのですが、実装が雑だと普通に使い物にならないレベルで激遅なので、多少実装の工夫が必要です。

雑にSplayしまくらない

Splay木はとりあえずSplayしておけば一応それっぽくは動くのですが、Rolling Hashでは更新がとにかく重いです。雑にSplayしまくるとその分コストになってのしかかるので、用法用量は適切にすべきでしょう。

update (eval) の中の乗算の回数をとにかく減らす

反転可能にしているがゆえに、正順序、逆順序のハッシュと基数を持っていると思いますが、これの更新は割とカオスになりがちです。
なので別で関数やメソッドに起こしてキレイにしたくなるものですが、それによって乗算が増えてしまうとまたこれがコストになります。

基数は2つ持つ必要はない

正順だろうが逆順だろうが、集約値の基数の最高次数は変わりません。
基数の更新のために乗算が増えますし、メモリも無駄なので、ノードで基数を2つ持つ必要はないです。 これは1点更新と反転ハッシュ取得可能なRolling Hashでも同様です。

update (eval) の呼び出し回数を最適化する

これです。 エッ!? 平衝二分木の update, push (eval, propagate) のタイミングがわからないですって? フッフッフ…… #競技プログラミング - Qiita

実のところめちゃくちゃ重要で、これが一番性能への影響が強かったです。
部分木の回転の際、根が正しく集約値を持っていることが保証できるなら、新しい根には元の集約値をそのまま代入しても結果は変わりません。これによってupdate (eval) の呼び出し自体を減らすことができます。

私が実装した限り、update (eval) は最悪8回程度の乗算を必要とするので、これが丸ごと消えるのは大きいです。

雑にメモリ確保しない(未実装・未検証)

未検証なのですが、雑にメモリ確保しているとキャッシュミス連発でひどいことになっているかもしれないなーなどと考えたりしています。
回避法としては、グローバルに大きめのVecを確保して、そこにノードを固めて置いておくとかが考えられそうです。

とはいっても、RustだとグローバルにVecを確保すると必然的にthread_local!+RefCellやstatic+Mutexのお世話になるわけで、そのアクセスコストとどっちがマシなのよっていう疑問もあり、実装も面倒なので、まあこのままでいいかなとも思っています。

実装例

これです。

https://github.com/tayu0110/tayu-procon/tree/master/string/src/rolling_hash

あとがき

ちょっと最後の方駆け足で雑になっちゃいましたが、まああんまり細かく書いてもな...という話なので、勘弁してください。(ちゃんと書くとほぼSplay Treeの実装記録になってしまう)

最初はちょっと大きめの制約ですぐTLEしてしまうダメな子だったのですが、工夫すればするほどどんどん性能が良くなるライブラリは作っていてとても楽しいです。ダメな子ほどかわいがりたくなるというやつかもしれません。

記事中や実装例で嘘が書いてあるのを見つけた方は、ぜひお教えください。

JOI 2023/2024 二次予選 B - 買い物 2 (Shopping 2)

パッと見であんまりよくわからなかったけど、解けてみると面白かったので。
こんなんセグ木に載るんかーといった感じでした。

解法

早い話が、全体の累積和と、各商品種( A _ iで分類される種類)の累積和をそれぞれ用意することができれば、各クエリに O(1)で回答することができる。( [L, R)の範囲内の全体の累積和から、種類 Tの同じ範囲内の累積和の半分を引けばよい)
これを愚直に実装すると、たぶん小課題の1, 2, 3あたりまでは点数がもらえるはず。

満点解法もこれをベースにすることで作れる。
全体の累積和は、追加制約があろうがなかろうが簡単に作れる。しかし、各商品種の累積和を作るというのがメモリ制限でも実行時間制限でも厳しい。(それぞれ O(N * min(M, N))程度の時間/空間が必要)

そこで、辞書をセグメント木に乗せることを考える(HashMapとかunordered_mapとかdictとか呼ばれるもの)。
辞書には、そのノードがカバーする範囲の商品種をキー、商品種ごとの定価の合計を値として設定する。クエリの際は、 Tをキーに定価の合計値を辞書引きながらトップダウンに足し込んでいけばよい。

クエリの際の計算量はそれほど問題ではなくて、各ノードの辞書の検索は O(1)と期待できるので、全体でも O(log N)で捌くことができる。
ちょっと引っ掛かりそうなのはセグメント木の構築の部分の時間計算量と空間計算量だと思う。

構築の方法としては、左の子の辞書を自身にクローンしたうえで、右の子の辞書の内容を1つずつ愚直に反復して、自身に追加していくという方法がとれる。(重複要素があれば、足し込む)
時間がかかるのは子のクローンとイテレートで、これは子ノードの要素数 nならいずれも O(n)程度はかかる。
全ノード要素数が最大となるのは、 N個の商品の種類がバラバラのときである。具体的に要素数を見積もると、ある段の1つのノードの要素数 nなら親ノードの要素数 2nとなるが、段全体のノードの数は半分に減るので、全体の要素数は変化しない。
また、1段分の全体としてのマージの計算量は、段全体の要素数からわかる(なぜなら、クローンとイテレートの計算量は要素数からわかるから)ので、結局 O(N)である。
したがって、上記のような構築方法をとった場合、全体として O(NlogN)程度の時間計算量となる。

先述の通り、各段 N要素で、 logN段あることから、空間計算量も O(NlogN)程度で収まっていることがわかる。

よって、全体としては時間計算量 O(NlogN + QlogN)で実行時間制限に間に合わせることができる。

実装例は以下の通り。普通はクエリで範囲を表す変数(lとかrとか)以外の変数を渡すことはないので、ライブラリもそうはなっておらず、その場で手書きすることに。
とはいってもセグ木ソラ書きは得意なので、あんまり面倒ではなかった。

Submission #48428078 - JOI 2023/2024 二次予選 過去問

あとがき

よく考えてみるとこのセグ木って、各商品種の「必要なところだけ作るセグ木」を、辞書を使って重ね合わせて1つのセグ木にまとめたってな感じになっているわけですね。

なので、ちゃんと作りこまれた動的セグ木を持っていれば、配列に動的セグ木をのせるだけで楽に解けるかもしれません。(めんどくさいのでやらないですが)

爆速なNTTを実装したい

それぞれ長さ N, Mであるような配列 a, bに対して、 c _ {k} = \sum _ {i+j=k} {a _ {i} b _ {j}}といった形の演算を畳み込みと呼び、多項式の乗算の実装等に使えます。

で、そのライブラリのverify用の問題がyosupo judgeにあります(https://judge.yosupo.jp/problem/convolution_mod)。

私はRust使いなのですが、ちょっと前まではRustの最速コード(私ではないです)は150msec以上かかっていたような記憶があります。C++最速のコード(全体の最速コードでもあります)は50msec弱くらいですので、だいぶ遅いです。

一応RustはCやC++並みのパフォーマンスを出せると謳われているのにこれはちょっと気に入らないですね?というわけで、頑張って高速化を目指したという話です。

タイトルが「爆速な畳み込み」ではないのは、結局畳み込みを実装するために内部で使用するNTTの高速化の話が中心になるからです。細かいことを気にしてはいけません(?)

畳み込みの話

細かい解説はできませんが、畳み込み定理というものがあります。本来はフーリエ変換に関連する定理らしいですが、離散フーリエ変換でも成り立ちます。
冒頭の式で言えば、 cの離散フーリエ変換は、 a, bの離散フーリエ変換を要素ごとに乗算したものと等しいという定理です。

したがって、離散フーリエ変換を高速で計算することができれば、

という手続きによって、2つのデータ列を畳み込むことができます。そんな訳なので、離散フーリエ変換をとにかく高速化しようということです。

DFTの話

DFT(Discrete Fourier Transform, 離散フーリエ変換)とは、フーリエ変換を離散的な時間・周波数領域で適用できるように改良した手法らしいです。
DFTは以下のような式で表されます。

 F(ω) = \sum _ {t = 0} ^ {N-1} {f(t) e ^ {- \frac{2 \pi tx}{N}i}}

 \piは円周率、 eネイピア数i虚数単位です。 Nは普通、データを格納した配列の長さでしょう。

変換後のデータ長を Mとすれば、各 ω = 0, 1, .., M-1に対して上の式を計算するので、愚直にやれば O(NM)の時間計算量となります。一般的に N \ne Mなことがあるのか知りませんが、以下 N = Mの前提で話を進めます。

DFTにはIDFT(Inverse Discrete Fourier Transform, 逆離散フーリエ変換)という逆変換があり、これは以下のような式で表されます。

 f(t) = \frac {1}{N}\sum _ {ω = 0} ^ {N-1} {F(ω) e ^ { \frac{2 \pi tω}{N}i}}

各データに掛けるeなんちゃらが共役な複素数になって、最後に \frac {1}{N}を乗じるだけということです。この \frac{1}{N}については、離散フーリエ変換の導出の過程で、正変換側につくことや、正変換・逆変換両側に \frac{1}{\sqrt {N}}としてつけることもあります。
ただ、畳み込み演算では正変換を2回、逆変換を1回をやるので、逆変換側に押し付けてしまったほうが得な気がします。多くの競プロの記事でも逆変換側に押し付けているので、それに倣うこととます。

(※このパートで1のN乗根について述べていますが、導出の過程を誤っているとTwitterでご指摘をいただきました。そのうち書き直すので読み飛ばしてください。ごめんなさい。)
どちらの変換でも eの肩に複素数を乗せた指数関数をかけていますが、よく見ると e ^ {2 \pi i}が混じっています。みんな大好きオイラーの公式から、 e ^ {i \pi} = -1なので、これは1に等しいです。さらに \frac {tω}{N}乗されているので、これは 1 N乗根を tω乗したものと言えます。

ここで、 1 N乗根 e ^ {- \frac{2 \pi}{N}i} W _ {N}と表すことで、正変換は以下のようにも表せます。

 F(ω) = \sum _ {t = 0} ^ {N-1} {f(t) W _ {N} ^ {tω}}

 1 N乗根は、複素数平面上の単位円周上を回転して再度 1に戻ってくるので、回転因子(Twiddle Factor)とも呼ばれます。逆変換における回転因子は正変換のそれと共役な複素数なので、逆方向に回転します。

DFTはこの回転因子を用いて、変換前の次数が N、変換後の次数が Mであれば、 M N列の行列によって表現することもできます。例えば、変換前後ともに次数4のDFTの変換行列は以下のようになります。


\begin{pmatrix}
W _ {4} ^ {0} & W _ {4} ^ {0} & W _ {4} ^ {0} & W _ {4} ^ {0} \\
W _ {4} ^ {0} & W _ {4} ^ {1} & W _ {4} ^ {2} & W _ {4} ^ {3} \\
W _ {4} ^ {0} & W _ {4} ^ {2} & W _ {4} ^ {4} & W _ {4} ^ {6} \\
W _ {4} ^ {0} & W _ {4} ^ {3} & W _ {4} ^ {6} & W _ {4} ^ {9}
\end{pmatrix}

FFTの話

FFT (Fast Fourier Transform, 高速フーリエ変換)はDFTを高速化するためのアルゴリズムです。

有名な手法としてはCooley-Tukey FFTがあり、これは次数を落としながら小さいサイズのDFTを行い、マージして元のサイズでの結果を得るという、分割統治によるアルゴリズムです。この記事でもこれをひたすら突き詰めていきます。

ほかにも、次数が素数であるときに適用できるRader FFTや、互いに素な2数の積であるときに適用できるPrime Factor FFTなどもあるらしいです。そのうち調べてみたいです。

Cooley-Tukey FFT

前の節で述べたように、分割統治によって次数を下げていくことで計算量を落とします。次数は合成数でさえあれば適用できますが、多くの場合2の冪乗で揃えます。以下、 N = 2 ^ l( lは正の整数)とします。

DFTの式について、 tを偶奇で分割してみます。

 \sum _ {t = 0} ^ {N-1} {f(t)W _ {N} ^ {tω}} = \sum _ {t=0} ^ {\frac{N}{2}-1} {f(2t) W _ {N} ^ {ω(2t)}} + \sum _ {t = 0} ^ {\frac{N}{2}-1} {f(2t+1) W _{N} ^ {ω(2t+1)}}

ところで、 pq = Nであるとき、 W _ {N} ^ {p} = W _ {q}となります。もともとの定義に立ち返れば、 W _ {N} ^ {p} = e ^ {- \frac {2 \pi p} {N}i} = e ^ {- \frac {2 \pi} {q}i} = W _ {q}となるのが確かめられます。

これを用いて上の式をさらに変形することで、

 \sum _ {t = 0} ^ {N-1} {f(t)W _ {N} ^ {tω}} = \sum _ {t=0} ^ {\frac{N}{2}-1} {g(t) W _ {\frac{N}{2}} ^ {tω}} + W _ {N} ^ {ω} \sum _ {t = 0} ^ {\frac{N}{2}-1} {h(t) W _{\frac{N}{2}} ^ {tω}}

となります(ただし、 g(t), h(t)はそれぞれ f(t)の偶数番目、奇数番目を集めた列)。

右辺をよくみると、 g(t), h(t)に対するDFTを行い、 h(t)への適用結果に回転因子を乗じて足す操作になっています。

分解は logN段で終わります。最下段では次数1のDFTが実行されますが、これは何もしないことと同じなので、 O(1)です。結果のマージは各段 O(N)で可能なので、結果的に計算量を O(NlogN)に削減できたことになります。

素朴な実装としては、以下のようになります。愚直なDFTの実行結果と比較していますが、浮動小数点数の比較なので、差が一定値以下のときパスとしています。

Rust Playground

 ω \ge \frac{N}{2}のときはどうやってマージするの?と一瞬迷いますが、よくよく考えると、 W _ {\frac{N}{2}} \frac{N}{2}乗したところで 1に戻りますから、 ω - \frac{N}{2}のときと一致します。よって、 ω - \frac{N}{2}の結果を再利用できます。 W _ {N}についても、 W _ {N} ^ {k + \frac{N}{2}} = -W _ {N} ^ {k}ですので、同様に ω - \frac{N}{2}の結果を利用できます。

NTTの話

これまで複素数でDFTを計算してきましたが、浮動小数点数の演算が必須のため、誤差が出ます。誤差が出るのは嬉しくないのでどうにかしたいですが、どうやらDFTは、1のN乗根が存在するような数のグループであれば適用できるようです。

そこで突然ですが、正の整数を素数 pで割った余りで構成される数のグループに適用することを考えます。これがNTT(Number Theoretic Transform, 数論変換)です。

このようなグループは必ず、 p-1乗して初めて 1と合同になる数( pの原始根)を持つので、これを使えます。

しかし素数なら何でもよいわけではなくて、FFTを適用するには 2 ^ l乗して 1になる数が必要なので、 p-1の素因数になるべく多く 2を含むような素数が望ましいです。このような素数をNTT Friendly素数と呼び、競プロでよく使われる998244353などが該当します。

上で述べたFFTアルゴリズムをそのまま用いてNTTを実装すると、以下のようになります。逆変換であるINTTも実装して、愚直な畳み込み演算と結果を比較しています。
Rust Playground

再帰

よく知られているように再帰関数は遅いことが多いです。これまでNTTは再帰関数として実装していたので、非再帰化します。

頑張って再帰関数の構造を紐解いて無理やり非再帰化しても良いですが、FFTのシグナルフローダイアグラムと呼ばれるものを用いれば、容易に非再帰のルーチンを導くことができます。フローダイアグラムが蝶の羽のような形をしているので、導かれたルーチンはバタフライ演算と呼ばれます。

フローダイアグラムはググればたくさん出てきますが、フローの最初と最後のどちらが大きいバタフライかで2種類のバリエーションがあるはずです。
上で導いたFFTアルゴリズム tを偶奇で分割する方法でしたが、これは時間間引き(Decimation In Time, DIT)というバリエーションにあたり、対応するフローダイアグラムは最初のバタフライが小さいほうです。

フローを流す前のデータの添え字を見てみるとヘンテコな並びをしていますが、これはビット反転順序という並びになっています。たとえば N = 8であれば、

 (0, 1, 2, 3, 4, 5, 6, 7) \\→ (000, 001, 010, 011, 100, 101, 110, 111) \\→ (000, 100, 010, 110, 001, 101, 011, 111) \\→ (0, 4, 2, 6, 1, 5, 3, 7)

というわけです。なので、正しい結果を得るためには、ルーチンの前にビット反転並べ替えをする必要があります。

実装は以下のような形になります。
Rust Playground

再帰関数で実装していた際には、データをインデックスの偶奇で分割するために作業用のメモリを必要としていましたが、この実装では元のデータ列をそのまま書き換えて結果を得ることができます。このような演算はin-place演算と言われます。
再帰実装の際はin-placeでない代わりに並べ替えは必要としませんでした。このように、作業用メモリを必要とする代わりに自然な並びで結果を出力するアルゴリズムとしてStockham FFTというものがあるようです。名前を聞いたことがあるくらいで詳しくは知らないので、これもそのうちちゃんと調べたいです。

ビット反転並べ替えのためにだいぶ非効率なことをしていますが、これは簡単に書ける方法をとったに過ぎません。実際にはより効率的な方法があるので、それを調べて使うのが良いでしょう。
大浦さんという方のFFTの解説ページに載っていたものが、今のところ試した中では最も速かったです。「大浦 FFT」などとググれば出てきます。

時間間引き・周波数間引き

先ほど「時間間引き」というCooley Tukey FFTのバリエーションが出てきましたが、これは tがDFTにおける時間変数を表し、これを分割して導出するためにそのような名前がついています。

DFTによって変換された関数は周波数関数になりますが、周波数変数 ωによって分割するバリエーションもあります。これが周波数間引き(Decimation In Frequency, DIF)であり、フローダイアグラムはフローの最初のほうが大きなバタフライになります。

こちらはフローを流しきった後のデータがビット反転順序となるので、ルーチンの最後でビット反転並べ替えをすればよいです。

これを実装した結果が以下です。時間間引き・周波数間引きのNTTの結果を比較して、一致することを確かめています。
Rust Playground

ちなみに、時間間引き・周波数間引きのルーチンは、互いに転置写像の関係にあり、機械的に導く方法があるようです。もっとちゃんと調べてみたいです。
(ネタ元:FFT の転置写像 – 37zigenのHP)

ビット反転並べ替えの削除

これまでみてきたように、DITはルーチンの冒頭、DIFはルーチンの末尾でビット反転並べ替えが必要でした。

ところで、畳み込み演算の用途であれば、DFTとIDFTは常に対で使います。また、ビット反転並べ替えは2回行えば元通りです。では、DFTにはDIFのルーチンを、IDFTにはDITのルーチンを用いれば、ビット反転並べ替えを削除できそうではないでしょうか?

これは実際、正しい結果を得ることができます。実装は以下の通りです。愚直な畳み込みの結果と比較しています。
Rust Playground

基数4のアルゴリズム

Twitterでご指摘いただきましたが、以下の乗算回数の考察には誤りがあります。
複素数でのFFTの際には W _ {4}の乗算が実部・虚部の入れ替えに相当するため、乗算が減ります。しかしNTTの際には W _ {4}の乗算は必要です。
下記のコード例でいうところのroot, root2, root3の乗算を含めて考察できていないのですが、これを含めると基数4のアルゴリズムであっても乗算の回数は減っていません。
従って、最外側のループの段数が減ること以上の効果はないかもしれません。
一応元の記載も残しておきますが、NTTの実装の改善には役立たない可能性があります。ごめんなさい。
以下、元の記載------------------

ルーチンの最内側のループの内部に注目すると、これはサイズ2のNTTです。これをもっと大きなサイズにできたら速くなりそうじゃないでしょうか?

サイズ2のNTTでは乗算1回、加算2回が必要ですから、サイズ4のNTTをサイズ2に分割して実行する場合は乗算4回、加算8回が必要です。

一方、分割しない場合は乗算1回、加算8回でできます。これはDFTの変換行列を眺めるとわかります。


\begin{pmatrix}
W _ {4} ^ {0} & W _ {4} ^ {0} & W _ {4} ^ {0} & W _ {4} ^ {0} \\
W _ {4} ^ {0} & W _ {4} ^ {1} & W _ {4} ^ {2} & W _ {4} ^ {3} \\
W _ {4} ^ {0} & W _ {4} ^ {2} & W _ {4} ^ {4} & W _ {4} ^ {6} \\
W _ {4} ^ {0} & W _ {4} ^ {3} & W _ {4} ^ {6} & W _ {4} ^ {9}
\end{pmatrix}
=
\begin{pmatrix}
W _ {4} ^ {0} & W _ {4} ^ {0} & W _ {4} ^ {0} & W _ {4} ^ {0} \\
W _ {4} ^ {0} & W _ {4} ^ {1} & W _ {4} ^ {2} & W _ {4} ^ {3} \\
W _ {4} ^ {0} & W _ {4} ^ {2} & W _ {4} ^ {0} & W _ {4} ^ {2} \\
W _ {4} ^ {0} & W _ {4} ^ {3} & W _ {4} ^ {2} & W _ {4} ^ {1}
\end{pmatrix}
=
\begin{pmatrix}
1 & 1 & 1 & 1 \\
1 & W _ {4} & -1 & -W _ {4} \\
1 & -1 & 1 & -1 \\
1 & -W _ {4} & -1 & W _ {4}
\end{pmatrix}

元のデータ列が (x _ {0}, x _ {1}, x _ {2}, x _ {3})であれば、 x _ {0} + x _ {2}, x _ {0} - x _ {2}, x _ {1} + x _ {3}, (x _ {1} - x _ {3})W _ {4}を前計算してから4回加減算すればよいので、結果、乗算1回、加算8回となります。

ルーチンの最内側のNTTの次数によって基数Nのアルゴリズムなどと言ったりするので、これは基数4のアルゴリズム、今までの実装は基数2のアルゴリズムとなります。

基数は大きくとるほど乗算の回数を減らすことが可能ですが、1度にロードしないといけない配列要素が増えるためキャッシュ効率的にデメリットが出てきたり、実装がやたら複雑な割に効果が薄くなったりしがちなので、基数4のアルゴリズムを使うことが多いようです。

また、基数2,4のアルゴリズムを組み合わせてフローダイアグラムを変則的に進めることで乗算回数を最適化するSplit-radix FFTというアルゴリズムもあります。残念ながら私の力では実装できなかったので、興味ある方は自力で調べてみてください。

以下、基数4のアルゴリズムの実装です。
Rust Playground

データ長が 4^lで表せない場合(つまり N = 2 ^ {2l+1}の形になる場合)、基数2のNTTで帳尻合わせをする必要があります。データ長を調整しても良いですが、メモリ効率的に嬉しくないですし、1のN乗根が取れない長さになるとそもそも正しい結果が得られません。

ルーチンの最内側の形をみると、NTTのほうはDFT行列を単純に乗じた形ではないし、INTTのほうはrootを変なかけ方をしているしで、微妙に頭が混乱しますが、基数2のアルゴリズムを力ずくで頑張って紐解くと、この形にできます。

回転因子をキャッシュする

ルーチンの中で使用する回転因子は、データ長 Nに対して N個しかありません。なので、キャッシュしておけば回転因子を求めるための掛け算を減らすことができます。

頑張ってやるだけなので、実装は省略します。

回転因子が 1であるときの乗算を削除する

見ればわかることですが、回転因子は最初は必ず 1です。
 1を掛け算するのは無駄なので、削除します。

ちゃんと数えてはいないですが、フローダイアグラムをみても 1の掛け算はそれなりの数がありますから、効果が期待できます。
 -1であるときも削除しても良いかもしれません。

やるだけなので、実装は省略します。以後の実装例でも無駄に実装が複雑になるので、1や-1の乗算を場合分けして削除することはしていません。

フローダイアグラムの添え字をすべてビット反転順序にする

先ほどビット反転並べ替えの文脈で登場した大浦さんの解説サイトで紹介されていた方法で、AtCoder Libraryで採用されている方法でもあります。

この方法では回転因子の指数の順序もビット反転順序になりますが、その代わり最内側のループではずっと同じ回転因子を乗じることになるため、最内側の回転因子の更新のための乗算がごっそりなくなります。

では、外側のループではどうやって回転因子を更新すればよいでしょうか?

大浦さんのサイトでは、単純に今欲しい指数のビット反転を行って回転因子を計算していますが、AtCoder Libraryでは(少なくとも私からすると)非常に天才的な方法で計算します。

AtCoder Libraryの実装を見ると、fft_info()という構造体があり、その中でrate2, rate3, irate2, irate3という配列を定義しているのがわかります。それぞれ、基数2のNTT、基数4のNTT、基数2のINTT、基数4のINTTで利用します。

例えばrate2の要素は、以下のような値となっています。

 rate2[i] = W _ {2 ^ 2} ^ {-1}W _ {2 ^ 3} ^ {-1}W _ {2 ^ 4} ^ {-1}...W _ {2 ^ {i+2}}

これを使って、例えばサイズ8のNTTの1段目に必要な回転因子を計算するとします。その場合、 W _ {8} ^ {0}, W _ {8} ^ {2}, W _ {8} ^ {1}, W _ {8} ^ {3}が欲しいですが、以下のような手続きでこれを得られます。

  • 初期値 w 1 (= W _ {8} ^ {0})とし、 0番目の回転因子とする
  • インデックス i 0とする
  •  w = w * rate2[i.trailing\_ones()]で更新する
  • 更新後の wは、 i+1番目の回転因子である
  •  i 2まで増やしながら上記の手続きを繰り返し、 0から 3番目の回転因子を順に得る

実際、

 i = 0, w = 1 \\
i = 1, w = 1 * rate[0] = W _ {4} = W _ {8} ^ {2} \\
i = 2, w = W _ {8} ^ {2} * rate[1] = W _ {8} ^ {2} * W _ {4} ^ {-1} * W _ {8} = W _ {8} ^ {2} * W _ {8} ^ {-2} * W _ {8} = W _ {8} ^ {1} \\
i = 3, w = W _ {8} * rate[0] = W _ {8} * W _ {4} = W _ {8} * W _ {8} ^ {2} = W _ {8} ^ {3}

となっており、欲しい回転因子の列が得られます。
凄いですね。競プロの典型か何かなのでしょうか?全く理解不能です…

実装例は以下のようになります…といっても、AtCoder Libraryと同じ実装なので、AtCoder Libraryを見ると良いでしょう。
Rust Playground

Rustに限らずC++でも、コンパイラオプションなどで拡張しない限り、コンパイル時計算できる演算量は限られています。しかしこの手法ではキャッシュがたかだか数十要素なので、コンパイル時計算も容易にできます。

多項式乗算の実装ではキャッシュを何度も使いまわしたい場面が多いでしょうから、とても嬉しいですね。
ただ、現状のAtCoderのRustバージョンでは、const fnの中でif分岐、whileループすらできないため、数十要素のキャッシュも簡単ではありませんが…

また、この手法ではフローダイアグラムの形が変わり、DIT、DIFともに向きが逆転します。ビット反転並べ替えが必要なパートも変わるため、NTTにはDITを、INTTにはDIFを使う必要があります。

ところで、フローダイアグラムの最初が大きなバタフライである(素朴なアルゴリズムではDIFに適用する)バタフライ演算を、Gentleman Sandeバラフライ、そうでないほうをCooley Tukeyバタフライと呼びます。
この名前は、バタフライの形に紐づいているのか、それともDIT、DIFに紐づいているのか、どちらなのでしょうか?いまだに分からないままです。
この節のアルゴリズムでは、フローダイアグラムの形とDIT・DIFの紐づきが変わっているので、どちらのバタフライ演算をCooley Tukey, Gentleman Sandeと呼べばいいのか分からないのです…ご存じの方、教えてください。

Six-step FFT

ルーチンの核となるループに注目すると、特に大きいサイズのNTTでは、かなり離れた位置にある要素を何度も取得しなければならない場面があり、キャッシュ効率的に嬉しくなさそうです。
キャッシュ効率を向上させるアルゴリズムとして、Six-step FFTというものがあります。この導出のためには、以下のようにDFTの式を変形します。

 F(kN _ {1} + l) \\
= \sum _ {pN _ {0} + q = 0} ^ {N - 1} {f(pN _ {0} + q)W _ {N} ^ {(kN _ {1} + l)(pN _ {0} + q)}} \\
= \sum _ {q = 0} ^ {N _ {0} - 1} {\{(\sum _ {p = 0} ^ {N _ {1} - 1} {f (pN _ {0} + q)W _ {N _ {1}} ^ {lp}})W _ {N} ^ {lq}\}W _ {N _ {0}} ^ {kq}}

ただし、 N = N _ {0}N _ {1}です。
入力データ列を N _ {0} N _ {1}列の行列として解釈すると、この式は以下のような手続きによる計算を意味します。

  • データ行列を転置する
  • 各行 N _ {1}サイズのFFTを実行する
  •  W _ {N} ^ {lq}をすべての要素に乗じる
  • データ行列を転置する
  • 各行 N _ {0}サイズのFFTを実行する
  • データ行列を転置する

 N _ {0}, N _ {1} \sqrt {N}に近くなるようにとれば、FFTの実行サイズは非常に小さくなり、キャッシュ効率が向上することが見込めます。
 N = 2 ^ {l}なので、一般に N _ {0} \le N _ {1}とすれば、 N _ {0} = N _ {1} 2N _ {0} = N _ {1}かの2通りがあります。

前者の場合は正方行列ですから、容易にin-placeでの行列転置が可能です。
後者の場合は正方行列ではなく、一般的に正方行列でない行列のin-place転置は容易ではありません。色々悩んだのですが、私は結局以下の資料からアルゴリズムを拝借しました。
[2011.11524] Speeding up decimal multiplication
要するには、 N \times 2N行列の場合、 N \times Nの部分行列2つを転置し、行を並び替えることで全体として転置できるというわけです。正方行列については上記の通りin-place転置可能、行の並び替えは O(N)サイズの作業メモリがあれば可能です。

また、FFTの結果はビット反転順序で返ってくるので、普通に実装すればFFTの実行ごとにビット反転並べ替えをする必要があります。
これを省くには、先に4ステップ目の行列転置を行い、 lをビット反転した指数を持つ回転因子を順に乗じればよいです。ここでも、前節の方法を応用できます。
ビット反転並べ替えを省いた場合、やはり全体としてビット反転順序で結果が返ってきます。

さらに、DFT・IDFTを対で使う場合は、それぞれ6ステップ目・1ステップ目の行列転置を省けます。転置の転置は元の行列なので、それは確かにそうですね。

さてここまで述べておいてなんですが、私が実装した結果では、Six-step FFTでの高速化は達成できませんでした。
真の理由はわかりませんが、高速な行列転置が実装できなかったことや、競プロで扱うようなデータサイズでは小さすぎてキャッシュ効率の向上度合いが大きくないというのが理由だと思っています。
2ステップ目、5ステップ目のFFTはそれぞれ、各行の中以外でのデータの依存関係はありませんから、並列実行が容易な気がします。もし並列実行が効果を発揮するような環境であれば、また結果は変わるのかもしれません。

実装は省略します。

ベクトル命令の使用

最終手段みたいなものですが、AVX2までのベクトル命令を用いた並列化を実装します。
なんでAVX512じゃないの?といえば、RustではstableでAVX512命令を使えないですし、そもそもAVX512をサポートするCPUを載せたPCやサーバなんて持っていないからです。エミュレータ的なものがあると聞きますが、コンテスト中にそんなの使うのは現実的ではないので、AVX2までで我慢します。

AVX2の場合、32bit整数8個の演算を同時に行えます。最内側のループのサイズが8以上であればレジスタへのロードは容易にできるので、効果は高そうです。

難点は、AVX系命令にもSSE系命令にも、整数除算・整数剰余算命令がないことです。
そこで、Montgomery剰余乗算を実装します。Montgomery剰余乗算のアルゴリズムを用いれば、加減乗算のみで乗算剰余を計算することが可能です。
Montgomery剰余乗算については既存の優れた記事がたくさんあるので、調べてみてください。(優れてはいないですが、私の既存記事もあります:Montgomery剰余乗算の学習メモ - 競プロ備忘録)

やるだけなので説明することもあまりありません。実装は以下の通りです。
Rust Playground
INTTでの \frac{1}{N}の乗算、データ列のマージも、 N \ge 8であれば、ベクトル化して高速化しています。入出力がある際は、Montgomery表現との間の変換・復元も8要素まとめて行うことで高速化できます。

上の実装はちょっと手抜きしていますが、本気でやれば N < 8の場合を除いて、最内側のループサイズが8未満の場合も含めてすべてのルーチンをベクトル化することができます。
この場合は単純に連続要素をロード・ストアすることができませんが、要素のロードはgather命令1つで可能です。計算後の要素のストアにはscatter命令を使いたくなりますが、AVX2までにはscatter命令が何故かありません。なので、shuffle, blend, pack, unpackなどを駆使して値を並べ替え、連続要素としてstoreする必要があります。
私のライブラリの実装ではそのあたりも頑張ってベクトル化しているので、暇な人は探してみてください。

あんまり適当な実装をし過ぎると、ベクトルレジスタが枯渇して意図せぬところにロード・ストアが挟まれ、速度が低下することがあります。その場合は計算の順序を入れ替えたり、ブロックを切ってライフタイムを制限したりといった工夫が必要になるかもしれません。

結果・まとめ

以上のような高速化の工夫を行った私のライブラリの実装で、冒頭に紹介したジャッジに投げてみたところ、最も高速なときで69msecとなりました。
C++の最速実装にはさすがに及びませんでしたが、Rustでの提出コード中では最速の実装となりました。爆速といえるかは微妙ですが、高速とはいってもよさそうでしょうか?
実際には入出力もそこそこ重いですから、その高速化が達成できないと、C++の実装には勝てない気がしています。

問題点としては、ちょっとコードが長すぎです。AtCoder Library Practice Contestにも全く同じ問題があるため、そちらのジャッジにも投げてみたのですが、cargo equipのminify機能を利用してもなお72000Byteです。こどふぉのジャッジではコード長制限オーバーです。
冗長なコードや無駄なコードもまだ多いので、コード削減は今後の課題になりそうです。

ちなみに、簡潔で高速なRust実装はyosupo judgeや上述のAtCoderのジャッジに存在していて、toomerさん(HBitさんと同一人物らしいです)の実装やtosさんの実装は、SIMDの実装がないにも関わらず非常に高速でかつコード長が短いです。
コンパクトなライブラリを志向する方は、そちらの実装を参考にするのが良いと思います。

理解しきれていないアルゴリズムもあります。「フローダイアグラムの添え字をすべてビット反転順序にする」の章については、フローダイアグラムの形から直感的には理解しているのですが、原理を論理的に示せてはいません。これからも既存の論文や記事を探すなどの努力が必要そうです。

参考にした資料

小野測器さんのコラムやる夫で学ぶディジタル信号処理FFT (高速フーリエ・コサイン・サイン変換) の概略と設計法OTFFT: High Speed FFT Library「Six-Step FFTとEight-Step FFTSpeeding up decimal multiplicationその他多数、参考にさせていただきました。