x64のAVXでマンデルブロ集合

Visual C++MASMマンデルブロ集合の計算を高速化。
SSE2でマンデルブロ集合 - merom686's blog
以前書いたSSE2
x64 アセンブリ言語プログラミング
64bit Windowsでの呼び出し規約

extern "C" void mandel_avx(__m256d *p);
__m256d v[6] = {x0_3, x4_7, y0_3, y4_7, out count0_7, m}; //擬似
mandel_avx(v);

このように使う。C++部分もx64で。

rem コマンドライン
ml64.exe /c /Cp /nologo /Fo"$(IntDir)%(Filename).obj" "%(Identity)"
rem 出力ファイル
$(IntDir)%(Filename).obj

拡張子.asmのファイルを追加し、カスタムビルドツールで上のように設定する。

for (int i = 0; i < m; i++){
    if (a * a + b * b > 4.0) break;
    t = a * a - b * b + x;
    b = 2.0 * a * b + y;
    a = t;
}

コメントにある a, b, x, y, m という記号は、上のような意味。

xmmA EQU xmm10
xmmB EQU xmm11
xmmC EQU xmm12
xmmD EQU xmm13
xmmE EQU xmm14
xmmF EQU xmm15
ymmA EQU ymm10
ymmB EQU ymm11
ymmC EQU ymm12
ymmD EQU ymm13
ymmE EQU ymm14
ymmF EQU ymm15

.data
align         16
D4  real8     4.0, 4.0, 4.0, 4.0

.code
mandel_avx proc

    sub       rsp, 168          ; この時点でrsp%16==8なので8byte余分に確保する
    vmovapd   [rsp    ], xmm6
    vmovapd   [rsp+ 16], xmm7
    vmovapd   [rsp+ 32], xmm8
    vmovapd   [rsp+ 48], xmm9
    vmovapd   [rsp+ 64], xmmA
    vmovapd   [rsp+ 80], xmmB
    vmovapd   [rsp+ 96], xmmC
    vmovapd   [rsp+112], xmmD   ; xmm6〜xmm15をスタックに保存
    vmovapd   [rsp+128], xmmE   ; 汎用レジスタは揮発性のものしか使っていないので保存不要
    vmovapd   [rsp+144], xmmF   ; #ほんとは例外処理のためにプロローグとかを

    vmovapd   ymm0, [rcx   ]    ; aをxで初期化
    vmovapd   ymm8, [rcx+32]
    vmovapd   ymm1, [rcx+64]    ; bをyで初期化
    vmovapd   ymm9, [rcx+96]
    mov       eax, [rcx+160]    ; m ゼロ拡張される
    mov       rdx, rax          ; 反復回数を保存しておく

    vmovapd   ymm6, ymm0        ; x
    vmovapd   ymmE, ymm8
    vmovapd   ymm7, ymm1        ; y
    vmovapd   ymmF, ymm9

    vpxor     xmm5, xmm5, xmm5  ; 4を超えたときのeaxを点ごとに保存する場所
    vpxor     xmmD, xmmD, xmmD
    vmovupd   ymm4, D4          ; 4.0 注:xmmCは別用途
    jmp       lp

    align     16                ; -:今使ってはいけない領域, ?:今後使われないデータ
lp:                             ; ymm = a b ? ? - - x y
    vmulpd    ymm3, ymm0, ymm1  ; ymm = a b ? ab - - x y
    vmulpd    ymmB, ymm8, ymm9
    vmulpd    ymm2, ymm0, ymm0  ; ymm = a b aa ab - - x y
    vmulpd    ymmA, ymm8, ymm8
    vmulpd    ymm1, ymm1, ymm1  ; ymm = a bb aa ab - - x y
    vmulpd    ymm9, ymm9, ymm9
    vaddpd    ymm3, ymm3, ymm3  ; ymm = a bb aa 2ab - - x y
    vaddpd    ymmB, ymmB, ymmB
    vaddpd    ymm0, ymm2, ymm6  ; ymm = aa+x bb aa 2ab - - x y
    vaddpd    ymm8, ymmA, ymmE
    vaddpd    ymm2, ymm2, ymm1  ; ymm = aa+x bb aa+bb 2ab - - x y
    vaddpd    ymmA, ymmA, ymm9
    vcmpnlepd ymm2, ymm2, ymm4  ; ymm = aa+x bb aa+bb>4 2ab 4 - x y
    vcmpnlepd ymmA, ymmA, ymm4
    vptest    ymm2, ymm2        ; 1個も4を超えなかったらZFが立つ
    jnz       diverged          ; ZFが立っていなかったらジャンプ
cont:
    vptest    ymmA, ymmA
    jnz       diverged2
cont2:
    vsubpd    ymm0, ymm0, ymm1  ; ymm = aa+x bb ? 2ab - - x y
    vsubpd    ymm8, ymm8, ymm9
    vaddpd    ymm1, ymm3, ymm7  ; ymm = aa-bb+x 2ab+y ? 2ab - - x y
    vaddpd    ymm9, ymmB, ymmF
    dec       rax
    jnz       lp                ; m回ループしたら終わり
    jmp       exit

diverged:
    vandnpd   ymm0, ymm2, ymm0  ; 4を超えたところだけビットが立っているのでandnを使う
    vandnpd   ymm1, ymm2, ymm1  ; 発散した点をゼロクリアし二度とここへ来ないようにする
    vandnpd   ymm3, ymm2, ymm3  ; ymm = aa+x bb ? 2ab - - x y //これらが対象
    vandnpd   ymm6, ymm2, ymm6
    vandnpd   ymm7, ymm2, ymm7
                                ; AVX2は使っていない
    vextractf128 xmmC, ymm2, 0  ; xmmC = d1 d0 (= i3 i2 i1 i0)
    vpshufd      xmmC, xmmC, 8  ; xmmC = i0 i0 i2 i0 //上位2つは何でもいい
    vextractf128 xmm2, ymm2, 1  ; xmm2 = d3 d2 (= i7 i6 i5 i4)
    vpshufd      xmm2, xmm2, 8  ; xmm2 = i4 i4 i6 i4
    vpunpcklqdq  xmm2, xmmC, xmm2;xmm2 = i6 i4 i2 i0 (= d3l d2l d1l d0l) //d0lはd0の下位64bit

    vmovq     xmmC, rax         ; ymm = - - mask - rax count - -
    vpshufd   xmmC, xmmC, 0     ; xmmC = eax eax eax eax
    vpand     xmmC, xmmC, xmm2  ; 反復回数をマスク、4を超えたところだけ残す
    vpor      xmm5, xmm5, xmmC  ; 発散した点へ回数を書き込む

    vpxor     xmmC, xmmC, xmmC  ; 比較用にゼロを作る
    vpcmpeqd  xmm2, xmmD, xmmC  ; xmm2 = (xmmD == 0) ? -1 : 0
    vptest    xmm2, xmm2
    jnz       cont              ; xmmDに0が一つでもあったらループに戻る

    vpcmpeqd  xmm2, xmm5, xmmC  ; xmm2 = (xmm5 == 0) ? -1 : 0
    vptest    xmm2, xmm2
    jnz       cont2             ; xmmD側が完了していればcont2へ戻ってもよい
    jmp       exit              ; 全ての点の計算が完了してたら終わり

diverged2:
    vandnpd   ymm8, ymmA, ymm8  ; 上と同じ
    vandnpd   ymm9, ymmA, ymm9
    vandnpd   ymmB, ymmA, ymmB
    vandnpd   ymmE, ymmA, ymmE
    vandnpd   ymmF, ymmA, ymmF

    vextractf128 xmmC, ymmA, 0  ; 上でもxmmCを使っている
    vpshufd      xmmC, xmmC, 8  ; 対となるymm4は、比較用の4.0
    vextractf128 xmmA, ymmA, 1
    vpshufd      xmmA, xmmA, 8
    vpunpcklqdq  xmmA, xmmC, xmmA

    vmovq     xmmC, rax
    vpshufd   xmmC, xmmC, 0
    vpand     xmmC, xmmC, xmmA
    vpor      xmmD, xmmD, xmmC

    vpxor     xmmC, xmmC, xmmC
    vpcmpeqd  xmmA, xmmD, xmmC
    vptest    xmmA, xmmA
    jnz       cont2

    vpcmpeqd  xmmA, xmm5, xmmC
    vptest    xmmA, xmmA
    jnz       cont2             ; 終わりなら、そのままexit:へ

exit:
    vmovq     xmmC, rdx         ; xmmC = rdx = m
    vpshufd   xmmC, xmmC, 0     ; xmmC = m m m m
    vpsubd    xmm5, xmmC, xmm5
    vpsubd    xmmD, xmmC, xmmD  ; xmm5とxmmDに各点の反復回数(32bit)が入る
    vmovdqa   xmmword ptr [rcx+128], xmm5
    vmovdqa   xmmword ptr [rcx+144], xmmD

    vmovapd   xmm6, [rsp    ]
    vmovapd   xmm7, [rsp+ 16]
    vmovapd   xmm8, [rsp+ 32]
    vmovapd   xmm9, [rsp+ 48]
    vmovapd   xmmA, [rsp+ 64]
    vmovapd   xmmB, [rsp+ 80]
    vmovapd   xmmC, [rsp+ 96]
    vmovapd   xmmD, [rsp+112]
    vmovapd   xmmE, [rsp+128]
    vmovapd   xmmF, [rsp+144]
    add       rsp, 168
    vzeroupper                  ; SSEを使えるように依存関係を断ち切っておく
    ret

mandel_avx endp

end

ymm9とymm10で文字数が変わるのが嫌なので、ymm10をymmAと書けるように設定。

ymm_k と ymm_k+8 の2系統を同時に処理することでレイテンシを隠蔽している。これは、x64でレジスタが倍増したから可能になった。少し高速化。

SSE2版と比べ、AVXで処理能力が2倍になったことはもちろんのこと、3オペランドになったことでも少し高速化している(無駄なmovapdが減った)。