下のコードは、SSE2でマンデルブロ集合の計算をするもの。かなり高速に動作すると思う。
このエントリでは、このコードの解説をするにょ。
#うるりの物置きのMandel100.zipにあるSSEのコードが元になっている。
; MASM 8.0 を使用 .686 .xmm .model flat, c .data align 16 D4 real8 4.0, 4.0 L1 oword -1 .code mandel_sse2 proc x:ptr real8, y:ptr real8, m:dword, count:ptr dword mov eax, x movapd xmm0, [eax] mov eax, y movapd xmm1, [eax] mov ecx, m movapd xmm6, xmm0 movapd xmm7, xmm1 pxor xmm5, xmm5 jmp brotloop align 16 brotloop: ; xmmx = a b ? ? temp count x y movapd xmm2, xmm0 ; xmm2 = a mulpd xmm0, xmm0 ; xmm0 = aa movapd xmm3, xmm1 ; xmm3 = b mulpd xmm1, xmm2 ; xmm1 = ab, free xmm2 movapd xmm2, xmm0 ; xmm2 = aa mulpd xmm3, xmm3 ; xmm3 = bb addpd xmm0, xmm6 ; xmm0 = aa+x addpd xmm2, xmm3 ; xmm2 = aa+bb = rr addpd xmm1, xmm1 ; xmm1 = 2ab cmpnlepd xmm2, D4 ; xmm2 = (rr > 4) ? -1 : 0 subpd xmm0, xmm3 ; xmm0 = aa-bb+x = a', free xmm3 movmskpd eax, xmm2 addpd xmm1, xmm7 ; xmm1 = 2ab+y = b' test eax, eax jnz escaped reentry: dec ecx jnz brotloop jmp exit escaped: movd xmm4, ecx pshufd xmm4, xmm4, 0 pand xmm4, xmm2 ; 反復回数をマスク por xmm5, xmm4 ; 発散した点へ回数を書き込む pxor xmm2, L1 ; ビット反転ってこれしか方法ないんだろうか andpd xmm0, xmm2 andpd xmm1, xmm2 andpd xmm6, xmm2 andpd xmm7, xmm2 ; 発散した点をクリアし二度とここに来ないようにする pxor xmm4, xmm4 pcmpeqd xmm4, xmm5 ; xmm4 = (xmm5 == 0) ? -1 : 0 pmovmskb eax, xmm4 test eax, eax ; xmm5に0が jnz reentry ; 一つでもあったらループに戻る exit: mov eax, m movd xmm4, eax pshufd xmm4, xmm4, 0 psubd xmm4, xmm5 mov eax, count movdqa [eax], xmm4 ret mandel_sse2 endp end
ここから、MASM用のソースコードを少しずつ読み進めていく。
; MASM 8.0 を使用 .686 .xmm .model flat, c
「;」は行コメント(行の途中からでも使える)。
.686と.xmmはアセンブル対象の命令を指示する。
SSEを使うので、.xmmを指定。これでSSE2の命令もアセンブルしてくれる。
.modelではメモリモデルと関数の呼出規約を指定。
呼出規約はC言語に合わせてcdeclとした。これなら面倒がない。
.data align 16 D4 real8 4.0, 4.0 L1 oword -1
.data。コード内で使う定数をメモリに置いてもらう。
SSE2のmovapdで読み込むためには、アドレス値が16の倍数である必要があるので、
align 16 としてデータの置き場所をそのように指定する。
D4が定数の名前、real8が型、4.0がメモリに置かれる値。
owordというのは16byteの型で、-1と書くことで128bit全てを1にしている。
.code
ここからがコード。
mandel_sse2 proc x:ptr real8, y:ptr real8, m:dword, count:ptr dword
//C++から呼び出すときはこんな感じ extern "C" void mandel_sse2(double *x, double *y, long m, long *c); int hoge() { double x[2], y[2]; long m, count[2]; … mandel_sse2(x, y, m, count); …
mandel_sse2という名前の関数が始まる。引数は4つ。x, y, count の3つはポインタ。
mov eax, x movapd xmm0, [eax] mov eax, y movapd xmm1, [eax] mov ecx, m movapd xmm6, xmm0 movapd xmm7, xmm1 pxor xmm5, xmm5
まず、ポインタxをeaxへコピーし、そのアドレスからdouble値2つをxmm0へロードする。
yも同様。mはecxへ読み込んでカウンタとして使う。
xmm0をxmm6へコピー、xmm1をxmm7へコピー、xmm5の値を0にする。
//C++版を書くとすればこんな感じ long mandel_cpp(double x, double y, long m) { long k, count; double a = x, b = y, t; for (k = 0; k < m; k++){ if (a * a + b * b > 4.0) break; t = a * a - b * b + x; b = 2.0 * a * b + y; a = t; } count = k; return count; }
ここからの処理内容をC++で書くと上のようになる。
MASM用コードのレジスタと、この変数たちとの対応を下の表で示す。
レジスタ | C++版の変数 |
---|---|
ecx(ダウンカウント) | k |
xmm0 | a |
xmm1 | b |
xmm5 | count |
xmm6 | x |
xmm7 | y |
jmp brotloop align 16
さて、次にループの開始を16byte境界にアラインさせておく。
CPUが少しでも命令を読みやすいように。
ここからが、速度的にクリティカルな部分だ。
brotloop: ; xmmx = a b ? ? temp count x y movapd xmm2, xmm0 ; xmm2 = a mulpd xmm0, xmm0 ; xmm0 = aa movapd xmm3, xmm1 ; xmm3 = b mulpd xmm1, xmm2 ; xmm1 = ab, free xmm2 movapd xmm2, xmm0 ; xmm2 = aa mulpd xmm3, xmm3 ; xmm3 = bb addpd xmm0, xmm6 ; xmm0 = aa+x addpd xmm2, xmm3 ; xmm2 = aa+bb = rr addpd xmm1, xmm1 ; xmm1 = 2ab
brotloop: はラベル。ループの開始地点。
複素数のa+biを2乗してx+yiを足す操作を繰り返すところ。
ここは全てSSE2の命令だが、加算・乗算・コピーだけなので意味は簡単。
ただし、依存関係のある演算同士はできるだけ離して置くようにする。
例えば、mulpd xmm0, xmm0 でaの2乗を計算した後、
xmm0(ていうかaの2乗)を使う演算はできるだけ離して配置したい。
そうすればレイテンシの長い命令を並列実行してくれる。
CPU内には命令を溜めておける場所もあるけど、最適化の効果は十分ある。
cmpnlepd xmm2, D4 ; xmm2 = (rr > 4) ? -1 : 0 subpd xmm0, xmm3 ; xmm0 = aa-bb+x = a', free xmm3 movmskpd eax, xmm2 addpd xmm1, xmm7 ; xmm1 = 2ab+y = b' test eax, eax jnz escaped
命令が入り交じってわかりにくいが、複素数の絶対値が4を超えたかの判定にのみ触れる。
cmpnlepdで、xmm2の値が4以下でない場合に全ビットを1にする。
「超える」ではなく「以下でない」というのは回りくどい表現だが、
レジスタの状態が通常の数でなかった場合の扱いが違う(ここでは関係ない)。
「D4」は、先に宣言しておいた定数。ここに4.0が2つ入っていて、これと比較する。
これを条件分岐に使いたいため、movmskpdで情報を汎用レジスタ(eax)に移す。
現在、SSE2でdouble値2つを同時に計算しているが、そのうち1つでも4.0を超えたら分岐する。
test命令でゼロフラグが立たない、つまり誰かが4.0を超えたときにjnzでジャンプする。
実はここで、cmpnlepdの代わりにcmplepdを使い、psubd xmm5, xmm2を入れれば、
escaped:への面倒な分岐をせずに同等の計算ができる。
だが、その1命令を削って高速化するのが今回の思想。
17命令のループだが、psubdは軽い命令なので、1%とかのレベルでしか変わらない。
reentry: dec ecx jnz brotloop jmp exit
reentry: は後で出てくる。処理完了時以外に、ここへ戻ってくるための場所。
decとjnzは、ダウンカウントループの定型文。
SSE命令はフラグを変更しないので、必要ならdecの後にaddpdとかを入れてもいい。
指定の回数(m回)に達したらループを抜け、jmpで終了処理へ飛ぶ。
escaped: movd xmm4, ecx pshufd xmm4, xmm4, 0 pand xmm4, xmm2 ; 反復回数をマスク por xmm5, xmm4 ; 発散した点へ回数を書き込む
少なくとも1つが4.0を超えた場合にここへ来る。
まず、現在のカウンタの値をxmm4へコピー。
pshufd xmm4, xmm4, 0 は、xmmレジスタの最下位32bitを全体へコピーする。
ここで、さっきのxmm2(4を超えた場合all 1)をマスクとして使う。
それを、反復回数を記録しておくxmm5へporで書き込めば、要素毎に回数を管理できる。
pxor xmm2, L1 ; ビット反転ってこれしか方法ないんだろうか andpd xmm0, xmm2 andpd xmm1, xmm2 andpd xmm6, xmm2 andpd xmm7, xmm2 ; 発散した点をクリアし二度とここに来ないようにする pxor xmm4, xmm4 pcmpeqd xmm4, xmm5 ; xmm4 = (xmm5 == 0) ? -1 : 0 pmovmskb eax, xmm4 test eax, eax ; xmm5に0が jnz reentry ; 一つでもあったらループに戻る
今度は、4を超えてないケースを残したいので、xmm2をxmm2の否定に置き換える。
SIMD命令にはnot命令がないので、仕方なく1とのxorをとることで反転させる。
ここはループの外なので、それほど速度を気にしなくていい。
4を超えたケースについては、ここで消えてもらう(0にする、もう4は超えない)。
次に、反復回数を記録しているxmm5を0と等しいか比較して、
0(回数未記録)が1つでもあったらループへ再突入する。
ecxが非0の状態でしかここへ飛んでこないことを使っている。
もし全て(と言っても2つだが)の点について計算が終わっていたら、jnzをスルーして終了処理へ。
exit: mov eax, m movd xmm4, eax pshufd xmm4, xmm4, 0 psubd xmm4, xmm5 mov eax, count movdqa [eax], xmm4 ret
計算は終わった。しかし、xmm5に入っている回数は加工が必要(ダウンカウントなので)。
xmm4をmの値で埋め、そこからxmm5を減算することで仕上がる。
それをcountのアドレスへ書き込んで終了。
ちょっと手抜きなので、回数はcount[0]とcount[2]に記録される。
これはdouble2個の処理だが、float4個の処理へも簡単に変更できる。
addpdをaddpsにするなどの機械的な変更だけで、ほぼ通る。
手元のPentiumMはSSE2のスループットが悪く、FPUと同等なのだが、
それでもレジスタの増加や表現力が高まったことによりFPU比で1.5倍高速化した。
精度が低くていいなら、4並列のSSE版で更に2倍くらい速くなる。
コードの見た目、若干最適化が甘いようにも見えるが、
PentiumMではパワー不足でこれ以上は詰め込めない感触だ。