えびちゃんの日記

えびちゃん(競プロ)の日記です。

ソースコードを見て計算量を下から抑えるのは怪しいという話

競プロ er はよく計算量の見積もりをします。「これこれの計算量は $O(\dots)$ なので十分高速である」といった具合で上から抑えることが多いです。 また、「これこれの計算量は $\Omega(\dots)$ なので TLE しそう」といった具合で下から抑えることもしばしばあります。

note:「これこれの計算量は $O(2^{2^n})$ なので TLE しそう」といった記号の使い方($O$ で下から抑えようとする)は、不正確な用法なので気をつけましょう。知らずに使っていた人はちゃんと勉強しましょう。

「下から抑える」について

下から抑えるというのは、見積もりたい値はこれ以上であるという値(下界かかい と呼ばれます)を求めるという意味の言い回しです。 ある $a$ を使って $a\le x$ と書けたら「$x$ は $a$ で下から抑えられる」と言います。 逆に、$x\le b$ は「$x$ は $b$ で上から抑えられる」と言います。

「上(下)から抑える」を「上(下)から評価する」と言ったり、それを求めることを「上(下)からの評価」と呼んだりします。

$O$ は定義から上から抑えるための記法(計算量は $O(f(n))$ ですと言ったら、計算量は $f(n)$ のオーダー以下であるということ)なので、下からの評価をしたい文脈とは相性が悪いです。

今日は、ソースコードを見て「この計算量は何々だから TLE するでしょ」という決めつけが必ずしも正しくないですという話をします。「実際に計算量は $\Theta(n^2)$ だけどアクセスの効率がよくて定数倍がめちゃ小さいので AC できる」という話はしません。

!! おまけ (2) が一番びっくりかもしれません。!!

問題提起

さて、次の C++ コードの計算量はどうなるでしょうか。上からも下からもしっかり見積もるべく、$O$ や $\Omega$ ではなく $\Theta$ を使います。

int sum_n(int n) {
    int res = 0;
    for (int i = 0; i <= n; ++i) res += i;
    return res;
}

disclaimer: 最近話題のツイートに起因して書いているものではありません*1。

多くの人は $\Theta(n)$ と思うのではないでしょうか。ループ中で $n+1$ 回の res += i を行いそうなためです(もちろん i <= n や ++i も考慮は必要です)。 実際 GCC では $\Theta(n)$ 時間ですが、Clang では $\Theta(1)$ 時間となります(どちらも -O2 での最適化は前提としています)。

コンパイラがこのような最適化を行ってくれるケースがあるというのを覚えておくべきでしょう。

解説

機械が実際に実行するのは C++ のソースコードではなく機械語ですから、それと対応しているアセンブリを読んでみましょう。これは上記の GCC や Clang が生成したものです。 手元の環境で様々なコンパイラを用意するのは面倒なので、そういうサービスを使います。

godbolt.org

画面左のウィンドウに先ほどのソースコードを入力し、画面上部の設定には下記を指定します(一度に指定できるのは一つずつのみです)。

言語 コンパイラ オプション
C++ x86-64 gcc 13.2 -O2
C++ x86-64 clang 16.0 -O2

画面右のウィンドウで出力のアセンブリが見られるので、これを見ていきましょう。 アセンブリの読み方に関して、説明を丁寧に書こうかと思ったのですが、途中で面倒になったのでやめてしまいました。記事の末尾に参考になりそうなものを挙げておくので各自勉強してください。

ここでは、上記で得られたアセンブリを読める程度の簡単な説明だけをして済ませることにします。

最初の引数が edi と呼ばれるレジスタ*2に入ります。rdi と edi は下位 32 bits を共有していて、rdi は 64 bits、edi は 32 bits です(下図のイメージ)。返り値は eax と呼ばれるレジスタに入れます。

[ ---------------- ---------------- ]  # <- rdi
                   ^^^^^^^^^^^^^^^^    # <- edi

各行は op arg, ... のような形式をしています(; 以降はコメントです。処理系によっては # だったりするようです)。op が命令の名前、arg が引数です。 命令が実行されるたびに、フラグレジスタと呼ばれるレジスタの値が変わったり変わらなかったりします。 フラグレジスタには、計算結果が 0 だったかどうかとか、オーバーフローしたかどうかとか、負だったか(符号ビットが立っているか)どうかとかの情報が入っています。 test x y という命令は x & y を計算します。js .label の命令ではフラグレジスタの状態(たとえば x & y の計算結果による)によって .label と書かれた位置にジャンプしたりしなかったりします。jxx の xx の部分がフラグレジスタのどのフラグを参照するかに対応します。js では SF(符号フラグ、負だったときに true)を見ます。

xor mov add などの命令は名前から想像できるような処理をします。記法にはいくつか流派があるのですが、ここでは計算結果は左側の引数に入ります。 たとえば add eax 2 であれば eax += 2 のようなものに相当します。

lea はおそらく元々はアドレス計算に関する命令なのですが、何かと都合がよい ので、加算や乗算をしたいときにしばしば登場します。どのような計算をしているかについてはコード中のコメントを参照してください。

コメントを添えていきます。関数に渡された時点での引数を n と置いておきます。 まずは GCC です。edi と rdi は、下位 32 bits を共有している(暗黙に同期されている)特殊な変数であるかのようなイメージで読んでください。eax と rax、edx と rdx などについても同様です。

sum(int):                               ; int sum(int edi) {
        test    edi, edi                ;     if ((edi & edi) < 0)
        js      .L4                     ;         goto L4;
        lea     ecx, [rdi+1]            ;     ecx = rdi + 1;
        xor     eax, eax                ;     eax ^= eax;
        xor     edx, edx                ;     edx ^= edx;
        and     edi, 1                  ;     edi &= 1;
        jne     .L3                     ;     if (edi != 0) goto L3;
        mov     eax, 1                  ;     eax = 1;
        cmp     eax, ecx                ;     if (eax == ecx)
        je      .L1                     ;         goto L1;
.L3:                                    ; L3:
        lea     edx, [rdx+1+rax*2]      ;     edx = rdx + 1 + rax*2;
        add     eax, 2                  ;     eax += 2;
        cmp     eax, ecx                ;     if (eax != ecx)
        jne     .L3                     ;         goto L3;
.L1:                                    ; L1:
        mov     eax, edx                ;     eax = edx;
        ret                             ;     return eax;
.L4:                                    ; L4:
        xor     edx, edx                ;     edx ^= edx;
        mov     eax, edx                ;     eax ^= edx;
        ret                             ;     return eax;
                                        ; }

各レジスタで計算しているものの意図を汲んだような名前をつけてコメントを添えると、次のような感じになります。

int sum(int edi) {
    if ((edi & edi) < 0)
        goto L4;  // if (n < 0) goto L4;
    ecx = rdi + 1;  // limit = n + 1;
    eax ^= eax;  // i = 0;
    edx ^= edx;  // res = 0;
    edi &= 1;
    if (edi != 0) goto L3; // if (n % 2 == 1) goto L3;
    eax = 1;  // i = 1;
    if (eax == ecx)
        goto L1;  // if (i == limit) goto L1;
L3:
    edx = rdx + 1 + rax*2;  // res += 1 + 2 * i
    eax += 2;  // i += 2;
    if (eax != ecx)
        goto L3;  // if (i != limit) goto L3;
L1:
    eax = edx;
    return eax;  // return res;
L4:
    edx ^= edx;
    eax ^= edx;
    return eax;  // return res;
}

境界値がややこしいですが、n が偶数なら (1+2)+(3+4)+...、奇数なら 1+(2+3)+(4+5)+... のように隣り合う要素をまとめて足していくような最適化をしています。 とはいえ、計算量は $\Theta(n)$ です。

次は Clang です。

sum(int):                               ; int sum(int edi) {
        mov     eax, edi                ;     eax = edi;
        lea     ecx, [rdi - 1]          ;     ecx = rdi - 1;
        imul    rcx, rax                ;     rcx *= rax;
        shr     rcx                     ;     rcx >>= 1;
        add     ecx, edi                ;     ecx += edi;
        xor     eax, eax                ;     eax ^= eax;
        test    edi, edi                ;     if ((edi & edi) >= 0)
        cmovns  eax, ecx                ;         eax = ecx;
        ret                             ;     return eax;
                                        ; }

こちらも意図を汲むと次のような感じです。

int sum(int edi) {
    eax = edi;  // tmp = n;
    ecx = rdi - 1;  // sum = n - 1; 
    rcx *= rax;  // sum *= tmp; // i.e. sum *= n
    rcx >>= 1;  // sum /= 2;  // sum == n * (n - 1) / 2;
    ecx += edi;  // sum += n;  // sum == n * (n + 1) / 2;
    eax ^= eax;  // res = 0;
    if ((edi & edi) >= 0)
        eax = ecx;  // if (n >= 0) res = sum;
    return eax;  // return res;
}

$n\ge 0\implies \sum_{i=0}^n i = n(n+1)/2$ を用いて $\Theta(1)$ の処理に最適化されています。

おまけ

Rust でもいろいろ遊べるので遊んでみます。

pub fn sum(n: u32) -> u32 {
    (0..=n).sum()
}
pub fn sum_128(n: u128) -> u128 {
    (0..=n).sum()
}

上記の関数を見てみます。pub にする必要があることに注意してください。つけ忘れると

<No assembly to display (~5 lines filtered)>

のような表示が出ます。オプションは -C opt-level=3 などにしておきます。

次のような感じです。適宜読んでください。

example::sum:                           ; fn sum(edi: u32) -> u32 {
        test    edi, edi                ;     if edi & edi == 0 {
        je      .LBB0_1                 ;         goto 'LBB0_1; }
        lea     eax, [rdi - 1]          ;     eax = rdi - 1;
        lea     ecx, [rdi - 2]          ;     ecx = rdi - 2;
        imul    rcx, rax                ;     rcx *= rax;
        shr     rcx                     ;     rcx >>= 1;  // (n - 1) * (n - 2) / 2
        lea     eax, [rdi + rcx]        ;     eax = rdi + rcx;
        dec     eax                     ;     eax -= 1;
        add     eax, edi                ;     eax += edi;  // (n - 1) * (n - 2) / 2 + n - 1 + n
        ret                             ;     return eax;  // == (n - 1) * n / 2 + n == (n + 1) * n / 2
.LBB0_1:                                ; 'LBB0_1:
        xor     eax, eax                ;     eax ^= eax;
        add     eax, edi                ;     eax += edi;
        ret                             ;     return eax;  // 0
                                        ; }

128-bit 整数の方は長いですが、128-bit 整数同士の演算自体にいくつかの命令が使われているだけで、やりたいこととしては大差ないでしょう。実はちゃんと読んでいません。

example::sum_128:                       ; fn sum_128(rdi:rsi: u128) -> u128 {
        mov     rax, rdi                ;     rax = rdi;
        or      rax, rsi                ;     rax |= rsi;
        je      .LBB2_1                 ;     if rax == 0 { goto 'LBB2_1; }
        push    r14                     ;     tmp_r14 = r14;
        push    rbx                     ;     tmp_rbx = rbx;
        mov     r8, rdi                 ;     r8 = rdi;
        add     r8, -1                  ;     (add, carry) = r8.carrying_add(18446744073709551615);
                                        ;     r8 = add;
        mov     r9, rsi                 ;     r9 = rsi;
        adc     r9, -1                  ;     r9 += 18446744073709551615 + carry 
        mov     rbx, rdi                ;     rbx = rdi;
        add     rbx, -2                 ;     (add, carry) = rbx.carrying_add(18446744073709551614);
                                        ;     rbx = add;
        mov     rcx, rsi                ;     rcx = rsi;
        adc     rcx, -1                 ;     rcx += 18446744073709551615 + carry;
        mov     rax, r9                 ;     rax = r9;
        mul     rbx                     ;     rax *= rbx;
        mov     r10, rdx                ;     r10 = rdx;
        mov     r11, rax                ;     r11 = rax;
        mov     rax, r8                 ;     rax = r8;
        mul     rbx                     ;     rax *= rbx;
        mov     rbx, rax                ;     rbx = rax;
        mov     r14, rdx                ;     r14 = rdx;
        add     r14, r11                ;     (add, carry) = r14.carrying_add(r11);
                                        ;     r14 = add;
        adc     r10d, 0                 ;     r10 += 0 + carry;
        mov     rax, r8                 ;     rax = r8;
        mul     rcx                     ;     rax *= rcx;
        add     rax, r14                ;     (add, carry) = rax .carrying_add(r14);
                                        ;     rax += carry;
        adc     edx, r10d               ;     edx += r10d + carry;
        imul    ecx, r9d                ;     ecx *= r9d;
        add     ecx, edx                ;     ecx += edx;
        shld    rcx, rax, 63            ;     rcx:rax >>= 63;
        shld    rax, rbx, 63            ;     rax:rbx >>= 63;
        add     rax, r8                 ;     (add, carry) = rax.carrying_add(r8);
                                        ; rax = add;
        adc     rcx, r9                 ;     rcx += r9 + carry;
        pop     rbx                     ;     rbx = tmp_rbx;
        pop     r14                     ;     r14 = tmp_r14;
        add     rax, rdi                ;     (add, carry) = rax.carrying_add(rdi);
                                        ;     rax = add;
        adc     rcx, rsi                ;     rcx += rsi + carry;
        mov     rdx, rcx                ;     rdx = rcx;
        ret                             ;     return rax:rdx;
.LBB2_1:                                ; 'LBB2_1:
        xor     eax, eax                ;     eax ^= eax;
        xor     ecx, ecx                ;     ecx ^= ecx;
        add     rax, rdi                ;     (add, carry) = rax.carrying_add(rdi);
                                        ;     rax = add;
        adc     rcx, rsi                ;     rcx += rsi + carry;
        mov     rdx, rcx                ;     rdx = rcx;
        ret                             ;     rax
                                        ; }

pub fn sum_3(n: u32) -> u32 { (0..=n).step_by(3).sum() } のようなものは $\Theta(1)$ にはなってくれませんでした。

おまけ (2)

たぶんここすごいです。

int square_sum(int n) {
    int res = 0;
    for (int i = 0; i <= n; ++i) res += i * i;
    return res;
}

これを Clang にやってもらいます。

square_sum(int):
        test    edi, edi
        js      .LBB0_1
        mov     eax, edi
        lea     ecx, [rdi - 1]
        imul    rcx, rax
        lea     eax, [rdi - 2]
        imul    rax, rcx
        shr     rax
        imul    edx, eax, 1431655766
        shr     rcx
        lea     eax, [rcx + 2*rcx]
        add     eax, edi
        add     eax, edx
        ret
.LBB0_1:
        xor     eax, eax
        ret

慣れていない人は 1431655766 ってな〜んだ?となりそうです。 ちゃんと C++ でコンパイルの通る形で書くと次のようなものになりそうです。

int square_sum(int n) {
  if (n < 0)
    goto LBB0_1;

  {
    unsigned edi = n;
    unsigned eax = edi;
    unsigned ecx = edi - 1;
    unsigned long rcx = (long)ecx * (long)eax;
    eax = edi - 2;
    ecx = rcx;
    unsigned long rax = (long)eax * (long)ecx;
    rax >>= 1;
    eax = rax;
    unsigned edx = eax * 1431655766u;
    rcx >>= 1;
    eax = 3 * rcx;
    eax += edi;
    eax += edx;
    return eax;
  }

LBB0_1:
  return 0;
}

int main() {
  for (int i = -10; i <= 10; ++i) {
    printf("%d%c", square_sum(i), i < 10 ? ' ' : '\n');
  }
  // 0 0 0 0 0 0 0 0 0 0 0 1 5 14 30 55 91 140 204 285 385
}

諸々を整理すると、考えるべきパートは概ね次のような感じです。

    unsigned edx = n * (n - 1) * (n - 2) / 2 * 1431655766u;
    unsigned ecx = n * (n - 1) / 2;
    return 3 * ecx + edx + n;

hint: 1431655766 == 0x55555556.

いくつか例を見てみましょう。32-bit 符号なし整数で考えて、オーバーフローは wrapping($2^{32}$ を法として考える)とします。

>>> (8000 * 1431655766) % (2**32)
2863316864
>>> (8001 * 1431655766) % (2**32)
5334
>>> (8002 * 1431655766) % (2**32)
1431661100

大胆予想です。 $$ x\equiv 0\pmod{3} \implies (x\times 1431655766)\bmod 2^{32} = \tfrac23 x. $$

種明かしというかなんというか、$1431655766 = (2^{32} + 2)/3$ です。 なので、$x = 3y$ とすると下記のようにできます。 $$ \begin{aligned} 3y\times ((2^{32} + 2)/3) &= y\times (2^{32}+2) \\\ &\equiv y\times 2 = 2y \pmod{2^{32}}. \end{aligned} $$ なるほど〜という感じです。

さて、これを踏まえて計算すれば、$\tfrac16 n(n+1)(2n+1)$ を求めていることがわかるでしょう($n(n-1)(n-2)/2$ は $3$ の倍数であることに注意)。 計算途中の各値は、必要に応じて 64-bit 整数を使いつつ求めているので、オーバーフローがあった場合も $\tfrac16 n(n+1)(2n+1)\bmod 2^{32}$ になっていそうです。

ところで、適切な範囲において、$\floor{(x\times 1431655766)/2^{32}} = \floor{x/3}$ のような話もありそうです。 $n$ を $2^{32}$ で割った整数部分というのは、n >> 32 だったり上位 dword を持ってきたりすることで高速に計算できますから、除算の高速化に貢献しそうです(実際、コンパイラはそうした類の最適化をしてくれます)。

おあそび

アセンブリを自分で書いて試せる状態になっているとお勉強が捗ると思うので、そういうことをしましょう。

↓ foo.s ↓

        .intel_syntax
        .file "foo.s"
        .text
        .globl foo
        .type foo, @function
foo:
        mov     %eax, %edi
        imul    %eax, %eax
        add     %eax, 2
        ret
        .section .note.GNU-stack,"",@progbits

↓ main.c ↓

#include <stdio.h>

int foo(int);

int main(void) {
    printf("%d\n", foo(5));
}

↓ コンパイル・実行 ↓

% as -o foo.o foo.s
% gcc foo.o main.c -o main
% ./main

あるいは、適当なプログラム prog.c を書いて gcc -S prog.c などをすると prog.s が得られるので、それを読むのもよいかもしれません。

また、M2 Mac などを使っている人は上記のアセンブリでは動かなさそう(怒られました)なので、別途考える必要があります。命令セットとかレジスタの名前とかが違いそうです。

↓ foo.s ↓

        .file   "foo.s"
        .text
        .global _foo
_foo:
        mov     w0, w0
        mul     w0, w0, w0
        add     w0, w0, #2
        ret
        .align  8

また、下記のようなことをすると楽しい気持ちになる人もいるかもしれません。

% objdump -D foo.o

関連資料

例によって日本語の資料はあまり探していません。よしなにしてくれたらうれしいです。

あわせて読みたい

読みたいかどうかは人によるというのはそう。

あとがき

込み入った解説を書かないとあっさりめな記事になるなあという気持ちです。アセンブリに関してなんかいろいろ書こうとしたんですが、「書いて誰が幸せになるんだろう」「各々が調べてくれたらいいや」という気分になって消してしまいました。文献は挙げたので意欲や興味がある人はがんばってほしいです。

それから、コンパイラがオーダーを落としてくれるようなケースは基本的には稀という気がしています。ただ稀だからといって無視していると足を掬われそうです。 未定義動作を利用されて、やばい最適化が起きて定数時間の処理になっていることはしばしばある気もします。

定数による除算なんかは(除算命令は重いので)除算を使わない形に書き換える最適化をしてくれがちです。こうした話に関してもアセンブリを読めると学習が捗るような気がします。

Clang が $\sum_{i=0}^n i^2$ についてもループなしにしてくれた上、除算の部分をよい感じに最適化してくれたので面白かったです。 $\floor{n/3}$ のようなタイプの最適化は知っていたのですが、$\tfrac23 n$ のようなタイプは知らなかったので勉強になりました。 さらに大きい $k$ に対して $\sum_{i=0}^n i^k$ を見てみても面白いかもしれませんね。

おわり

おわりです。

*1:こんな例は 5000 年前から 3 万回は見てきたため。実際、元々 3 週間くらい前に下書きをしていた記事です。

*2:レジスタのことは、edi や ecx など名前のついたメモリだと思ってもらえばいいと思います。どのようなレジスタがあるかは、関連資料に挙げたものを読んでもらえるとよさそうです。