Cで継続渡し(末尾最適化バージョン)

Cで継続渡しを書いてから、もう少しがんばれば末尾最適化もできそうだということに気付いて、やってみたらなんとかできた。しかも、絶対にアセンブラが必要だろうと思っていたんだけど結果的にはアセンブラに頼らずにCの範囲内で実現することができた。さすが高級アセンブラ。

  • fib2.c
#include <stdio.h>
#include <stdlib.h>

#include "obj.h"

int main(int argc, char *argv[])
{
    // 汎用レジスタのつもり
    void *r1;
    void *r2;

    closure_t c = newClosure(0);
    c->func = (func_t) &&fib_cps_0;

    for (int i = 1; i < 10; i++) {
        addRef((object_t) c);

        // printf("Fibonacci(%d) = %d\n", i, fib_cps(i, c));

        r1 = (void*) i;
        r2 = (void*) c;
        goto fib_cps;

 fib_cps_0: // int fib_cps_0(int n, closure_t k)
        {
            int n = (int) r1;
            closure_t k = (closure_t) r2;
            release((object_t) k);

            printf("Fibonacci(%d) = %d\n", i, n);
        }
    }

    release((object_t) c);

    return EXIT_SUCCESS;

 fib_cps: // int fib_cps(int n, closure_t k)
    {
        int n = (int) r1;
        closure_t k = (closure_t) r2;
        if ((n == 1) || (n == 2)) {
            r1 = (void*) 1;
            r2 = (void*) k;
            goto *(k->func);
            // return (int) k->func(1, k);
        } else {
            closure_t c = newClosure(2);
            c->func = &&fib_cps_1;
            c->vars[0] = (void*) k;
            c->vars[1] = (void*) n;
            r1 = (void*) (n - 1);
            r2 = (void*) c;
            goto fib_cps;
            // return fib_cps(n - 1, c);
        }
    }

 fib_cps_1: // int fib_cps_1(int v1, closure_t k1)
    {
        int v1 = (int) r1;
        closure_t k1 = (closure_t) r2;

        int n = (int) k1->vars[1];
        closure_t c = newClosure(3);
        c->func = &&fib_cps_2;
        c->vars[0] = k1->vars[0];
        c->vars[1] = (void*) v1;
        c->vars[2] = (void*) n;
        release((object_t) k1);
        r1 = (void*) (n - 2);
        r2 = (void*) c;
        goto fib_cps;
        // return (int) fib_cps(n - 2, c);
    }

 fib_cps_2: // int fib_cps_2(int v2, closure_t k2)
    {
        int v2 = (int) r1;
        closure_t k2 = (closure_t) r2;

        int v1 = (int) k2->vars[1];
        closure_t k = k2->vars[0];
        release((object_t) k2);
        r1 = (void*) (v1 + v2);
        r2 = (void*) k;
        goto *(k->func);
        // return (int) k->func(v1 + v2, k);
    }
}

最初ラベルのアドレスを得る方法とかポインタ変数の指すアドレスにgotoする方法が分からなくて困った(それぞれ&&label, goto *varなどと書けばいい)。

このコードは全く実用的ではないけど、末尾呼び出しをジャンプに変換すると本当に関数呼び出しが消えるという実証にはなっていると思う。こういうことをコンパイラがやってくれると、プログラマは効率の心配をせずになんでも再帰呼び出しにしてしまうことができる。

以下は28日のfib.h, fib.cの一部を別ファイルとしたもの(内容はほとんど同じですが微妙に修正してあります)

  • obj.h
#ifndef __OBJ_H__
#define __OBJ_H__

typedef void* (*func_t)();
typedef void (*proc_t)();

typedef struct object {
    int ref_count;
    size_t size;
    proc_t destructor;
} *object_t;

void addRef(object_t obj);
void release(object_t obj);
object_t newObject(size_t size, proc_t destructor);
void deleteObject(object_t obj);

typedef struct closure {
    // inherited from object_t
    int ref_count;
    size_t size;
    proc_t destructor;
    // own data
    func_t func;
    int var_num;
    void *vars[0];
} *closure_t;

closure_t newClosure(int var_num);

#endif /* __OBJ_H__ */
  • obj.c
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

#include "obj.h"

void addRef(object_t obj)
{
    obj->ref_count++;
}

void release(object_t obj)
{
    obj->ref_count--;
    if (obj->ref_count <= 0) {
        if (obj->destructor)
            obj->destructor(obj);
        deleteObject(obj);
    }
}

object_t newObject(size_t size, proc_t destructor)
{
    assert(size >= sizeof(struct object));

    object_t obj = malloc(size);
#ifdef DEBUG
    fprintf(stderr, "malloc: %p\n", obj);
#endif
    obj->ref_count = 0;
    obj->size = size;
    obj->destructor = destructor;
    addRef(obj);
    return obj;
}

void deleteObject(object_t obj)
{
#ifdef DEBUG
    size_t size = obj->size;
    char *p = (char*) obj;

    // filled by `de ad be ef' ... pattern
    for (int i = 0; i < size / 4; i++) {
        *p++ = '\xde'; *p++ = '\xad';
        *p++ = '\xbe'; *p++ = '\xef';
    }
    switch (size % 4) {
        case 3: p[2] = '\xbe';
        case 2: p[1] = '\xad';
        case 1: p[0] = '\xde';
    }
#endif /* DEBUG */

    free(obj);

#ifdef DEBUG
    fprintf(stderr, "free: %p\n", obj);
#endif
}

closure_t newClosure(int var_num)
{
    assert(var_num >= 0);
    size_t size = sizeof(struct closure) + sizeof(void*) * var_num;
    closure_t c = (closure_t) newObject(size, NULL);
    c->var_num = var_num;
    return c;
}