小さい文字列からASCII外のバイトを見つけたい

Pythonの内部処理で、与えられたUTF-8の文字列がASCII文字のみを含むかをテストするコードを書いています。

size_t型が8バイトの時、ポインタを char* から size_t* にキャストして、0x8080808080808080ullと論理積を取れば8バイトを一度に処理できます。 単純化すると次のようなコードになります。

#define ASCII_MASK 0x8080808080808080ull

size_t find_first_nonascii(const char *start, const char *end)
{
    const char *p = start;

    // 8バイトずつ処理
    while (p <= end - 8) {
        if (*(size_t*)p & ASCII_MASK) {
            break;
        }
        p += 8;
    }
    // 1バイトずつ処理
    while (p < end) {
        if (*p & 0x80) {
            return p-start;
        }
        p++;
    }
    return end-start;
}

1バイトずつ処理する部分が気になります。これは最下位ビットから1になっているビットを探して位置を返してくれるGCC/clangのビルトイン関数 __builtin_ctzll を使えば、前半部分をこのように高速化できます。

    while (p <= end - 8) {
        size_t x = *(size_t*)p & ASCII_MASK;
        if (x) {
            // p[0] がASCII外の時、 ctzll(0x80) == 7, (7-7)/8 == 0
            // p[1] がASCII外の時、 ctzll(0x8000) == 15, (15-7)/8 == 1
            return p-start + (__builtin_ctzll(x) - 7) / 8;
        }
        p += 8;
    }

しかし問題は後半部分です。 8バイト未満の端数部分をsize_tに格納できれば同じ高速化ができます。たとえば次のようになります。

    // (end-p)が3のとき、 3*8=24, 1ull << 24 = 0x1000000, 0x1000000-1 = 0xffffff.
    // これとmaskしてやれば、4バイト以降は無視できる。
    size_t x = *(size_t*)p & ((1ull << (end - p) * 8) - 1) & ASCII_MASK;
    if (x) {
        return p - start + (__builtin_ctzll(x) - 7) / 8;
    }
    return end - start;

これでバイト数*2回だった分岐が1回に減らせてめでたしめでたし、、、なのですが、残念ながら入力の文字列の範囲外を読んでからマスクしているだけなので、範囲外アクセスが発生しています。 試しに clang の AddressSanitizer を使ってみるとエラーが出てしまいました。

アライメントさえ合っていれば、8バイトのreadがセグメンテーション違反を起こすことは無いと思いますが、範囲外アクセスはやはり避けたいです。 そこで次のように switch 文を書いてみました。

    size_t u = (size_t)(p[0]);
    switch (end - p) {
        default:
            u |= (size_t)(p[7]) << 56ull;
        // fall through
        case 7:
            u |= (size_t)(p[6]) << 48ull;
        // fall through
        case 6:
            u |= (size_t)(p[5]) << 40ull;
        // fall through
        case 5:
            u |= (size_t)(p[4]) << 32ull;
        // fall through
        case 4:
            u |= (size_t)(p[3]) << 24;
        // fall through
        case 3:
            u |= (size_t)(p[2]) << 16;
        // fall through
        case 2:
            u |= (size_t)(p[1]) << 8;
            break;
        case 1:
            break;
    }
    if (u & ASCII_CHAR_MASK) {
        return p - start + (__builtin_ctzll(u & ASCII_CHAR_MASK) - 7) / 8;
    }

これで範囲外アクセスはなくなりました。 clang が生成するアセンブリ(intel形式)は次のような感じになります。

...
                u |= (size_t)(p[7]) << 56ull;
 147:   44 0f b6 59 07          movzx  r11d,BYTE PTR [rcx+0x7]
 14c:   49 c1 e3 38             shl    r11,0x38
 150:   4d 09 da                or     r10,r11
                u |= (size_t)(p[6]) << 48ull;
 153:   44 0f b6 59 06          movzx  r11d,BYTE PTR [rcx+0x6]
 158:   49 c1 e3 30             shl    r11,0x30
 15c:   4d 09 da                or     r10,r11
                u |= (size_t)(p[5]) << 40ull;
 15f:   44 0f b6 59 05          movzx  r11d,BYTE PTR [rcx+0x5]
 164:   49 c1 e3 28             shl    r11,0x28
 168:   4d 09 da                or     r10,r11
                u |= (size_t)(p[4]) << 32ull;
...

1バイトずつリードして、左シフトして、ORしてますね。switchのジャンプテーブルは使われているので分岐を減らす目的は達成できました。

しかし、アライメントを保証することで、範囲外リードを安全に行うような方法があればそっちのほうがいいので、その方法を現在探しています。

このブログに乗せているコードは引用を除き CC0 1.0 で提供します。