Looking back on 2022
I wrote about private matters in Notion.
How GC in unsafe loxido works
ceronman/loxido: Rust implementation of the Lox programming language.
How to allocate an object
```rust
let bar = gc.alloc(LoxString::from_string("bar".to_owned()));
```
Here are LoxString::from_string and the GcObject implementation. They are straightforward.
```rust
pub struct LoxString {
    pub header: GcObject,
    pub s: String,
    pub hash: usize,
}

impl LoxString {
    pub fn from_string(s: String) -> Self {
        let hash = LoxString::hash_string(&s);
        LoxString {
            header: GcObject::new(ObjectType::LoxString),
            s,
            hash,
        }
__snip__

pub struct GcObject {
    marked: bool,
    next: Option<NonNull<GcObject>>,
    obj_type: ObjectType,
}
```
Gc::alloc does the trick. I added a comment for each line.
```rust
pub fn alloc<T: Display + 'static>(&mut self, object: T) -> GcRef<T> {
    unsafe {
        // Allocate the object on the heap.
        let boxed = Box::new(object);
        // Get a raw pointer via Box::into_raw.
        // The pointer is NonNull<T>.
        let pointer = NonNull::new_unchecked(Box::into_raw(boxed));
        // This assumes T has GcObject as its first field with the exact same alignment.
        let mut header: NonNull<GcObject> = mem::transmute(pointer.as_ref());
        // Gc.first is a linked list of GcObject (= header).
        header.as_mut().next = self.first.take();
        // The newly allocated object becomes the head of the linked list.
        self.first = Some(header);
        // GcRef is a struct with one field: pointer: NonNull<T>.
        GcRef { pointer }
    }
}
```
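The soundness of that transmute rests on every GC-ed type putting its GcObject header first, with a fixed layout. A minimal, self-contained sketch of the trick (Header, the repr(C) annotations, and the cast are illustrative; loxido's actual types differ):

```rust
use std::ptr::NonNull;

// Illustrative stand-ins for loxido's header and a GC-ed type.
#[repr(C)] // fix the field order so the header really comes first
struct Header {
    marked: bool,
}

#[repr(C)]
struct LoxString {
    header: Header, // must be the first field
    s: String,
}

fn main() {
    let boxed = Box::new(LoxString {
        header: Header { marked: false },
        s: "bar".to_owned(),
    });
    let pointer = NonNull::new(Box::into_raw(boxed)).unwrap();
    // Because the header is the first field, a pointer to the whole
    // object is also a valid pointer to its header.
    let mut header: NonNull<Header> = pointer.cast();
    unsafe {
        header.as_mut().marked = true;
        // The write is visible through the original pointer.
        assert!(pointer.as_ref().header.marked);
        // Rebuild the Box so the allocation is freed (no leak).
        drop(Box::from_raw(pointer.as_ptr()));
    }
}
```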
Mark
Adding my comments inline.
```rust
fn mark_roots(&mut self) {
    // Mark VM stack.
    for &value in &self.stack[0..self.stack_len()] {
        self.gc.mark_value(value);
    }
    // Mark call frames.
    for frame in &self.frames[..self.frame_count] {
        self.gc.mark_object(frame.closure)
    }
    for &upvalue in &self.open_upvalues {
        self.gc.mark_object(upvalue);
    }
    self.gc.mark_table(&self.globals);
    self.gc.mark_object(self.init_string);
}
```
Trace
mark_object pushes in-use objects onto grey_stack. blacken_object then marks objects recursively. Unlike Rust GC, loxido has no Trace trait, so blacken_object has to enumerate every GcRef field by hand.
```rust
fn trace_references(&mut self) {
    while let Some(pointer) = self.grey_stack.pop() {
        self.blacken_object(pointer);
    }
}
```
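In other words, blacken_object is essentially a big match over ObjectType. A self-contained toy version of that dispatch, using integer ids instead of raw pointers (the types and fields here are made up for illustration, not loxido's real ones):

```rust
// Illustrative object model: each variant lists the inner
// references that blacken_object must mark by hand.
#[derive(Clone, Copy)]
enum ObjectType {
    LoxString,                   // holds no inner references
    Closure { function: usize }, // holds one inner object (by id here)
}

struct Gc {
    marked: Vec<bool>,      // mark bit per object id
    types: Vec<ObjectType>, // type tag per object id
    grey_stack: Vec<usize>, // ids waiting to be blackened
}

impl Gc {
    fn mark_object(&mut self, id: usize) {
        if !self.marked[id] {
            self.marked[id] = true;
            self.grey_stack.push(id);
        }
    }

    // One match arm per GC-ed type: mark every inner reference.
    fn blacken_object(&mut self, id: usize) {
        let ty = self.types[id];
        match ty {
            ObjectType::LoxString => {} // nothing to trace
            ObjectType::Closure { function } => self.mark_object(function),
        }
    }

    fn trace_references(&mut self) {
        while let Some(id) = self.grey_stack.pop() {
            self.blacken_object(id);
        }
    }
}

fn main() {
    let mut gc = Gc {
        marked: vec![false, false],
        types: vec![ObjectType::LoxString, ObjectType::Closure { function: 0 }],
        grey_stack: Vec::new(),
    };
    gc.mark_object(1);     // root: the closure
    gc.trace_references(); // blackening reaches the string it references
    assert!(gc.marked[0] && gc.marked[1]);
}
```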
Sweep
If an object is not marked, sweep releases it from the heap.
```rust
fn sweep(&mut self) {
    let mut previous: Option<NonNull<GcObject>> = None;
    let mut current: Option<NonNull<GcObject>> = self.first;
    while let Some(mut object) = current {
        unsafe {
            let object_ptr = object.as_mut();
            current = object_ptr.next;
            if object_ptr.marked {
                // Keep the object; clear the mark for the next cycle.
                object_ptr.marked = false;
                previous = Some(object);
            } else {
                // Unlink the object from the list and free it.
                if let Some(mut previous) = previous {
                    previous.as_mut().next = object_ptr.next
                } else {
                    self.first = object_ptr.next
                }
                Box::from_raw(object_ptr);
            }
        }
    }
}
```
Finalize, Sweep and Rooting - Understanding Rust GC
Finalize Trait
Previously we explored Trace in Trace - Understanding Rust GC - higepon blog. Now let's look into the Finalize trait. It is really simple: every GC-ed object should implement finalize.
```rust
pub trait Finalize {
    fn finalize(&self) {}
}
```
Let's look at the finalize implementations for arrays and Option<T>. They are empty. Hmm, I'm not quite sure why. Let me start from collect_garbage and come back here later.
```rust
// Array
impl<T: Trace, const N: usize> Finalize for [T; N] {}
// Option<T>
impl<T: Trace> Finalize for Option<T> {}
```
collect_garbage
In collect_garbage, mark returns the unmarked objects (= objects no longer in use), and the gc calls finalize_glue() for each of them.
```rust
let unmarked = mark(&mut st.boxes_start);
if unmarked.is_empty() {
    return;
}
for node in &unmarked {
    Trace::finalize_glue(&(*node.this.as_ptr()).data);
}
```
finalize_glue calls finalize on the trait. So if, say, an array is unmarked, the gc eventually calls the array's empty finalize. Now I remember: finalize gives the object a chance to do any necessary work before it is deallocated. For example, an object may want to close a file associated with it.
Sweep
Now we know that the gc collects unmarked objects and calls finalize on them. It's time to sweep them.
```rust
unsafe fn sweep(finalized: Vec<Unmarked>, bytes_allocated: &mut usize) {
    let _guard = DropGuard::new();
    for node in finalized.into_iter().rev() {
        if (*node.this.as_ptr()).header.marked.get() {
            continue;
        }
        let incoming = node.incoming;
        let mut node = Box::from_raw(node.this.as_ptr());
        *bytes_allocated -= mem::size_of_val::<GcBox<_>>(&*node);
        *incoming = node.header.next.take();
    }
}
```
This is actually very interesting: it recreates the Box from the raw pointer using from_raw, so dropping the Box frees the memory. That makes sense!
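The into_raw/from_raw pairing is the standard-library idiom for handing ownership of a heap allocation to a raw pointer and later taking it back, so that dropping the rebuilt Box frees the allocation. A minimal roundtrip:

```rust
fn main() {
    let boxed = Box::new(42_u32);
    // Give up ownership: the allocation now lives behind a raw pointer
    // and will NOT be freed automatically.
    let raw: *mut u32 = Box::into_raw(boxed);
    unsafe {
        assert_eq!(*raw, 42);
        // Reconstruct the Box; dropping it frees the allocation,
        // exactly what sweep does for unmarked GcBoxes.
        let back = Box::from_raw(raw);
        assert_eq!(*back, 42);
    }
}
```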
Rooting
Let's get back to root and unroot. A Tour of Safe Tracing GC Designs in Rust - In Pursuit of Laziness has a good summary and examples of what rooting is and why we need it.
In one word: Rooting. In a garbage collector, the objects “directly” in use on the stack are the “roots”, and you need to be able to identify them.
```rust
struct Foo {
    bar: Option<Gc<Bar>>,
}

// this is a root
let bar = Gc::new(Bar::new());

// this is also a root
let foo = Gc::new(Foo::new());

// bar should no longer be a root (but we can't detect that!)
foo.bar = Some(bar);

// but foo should still be a root here since it's not inside
// another GC'd object
let v = vec![foo];
```
To track root objects, the gc maintains a root count in each GcBox. In short, a GcBox with a root count > 0 is a root object. Rooting - Designing a GC in Rust explains it very well. Note that the count is incremented or decremented when a Gc<T> object is moved.
Summary
- Finalize gives an object an opportunity to do some cleanup work before it is freed.
- Sweep frees unmarked objects by recreating a Box from the raw pointer.
- To track objects directly in use on the stack, the gc maintains root counts.
Trace - Understanding Rust GC
Goal
I want to understand how Rust GC works to see if I can use it in my Scheme VM interpreter, which will be written in Rust.
How to use
GC-ed objects should implement Trace and Finalize. You should use Gc::new instead of Box::new to allocate objects on the heap. Here is an example from the official documentation.
```rust
let x = Gc::new(1_u8);
let y = Gc::new(Box::new(Gc::new(1_u8)));

#[derive(Trace, Finalize)]
struct Foo {
    a: Gc<u8>,
    b: u8,
}

let z = Gc::new(Foo { a: x.clone(), b: 1 });
```
What is Gc<Foo>?
The z variable above has type Gc<Foo>. You can access the fields of Foo like z.a. But why? Gc<Foo> itself has no field named a. Because Gc implements the Deref trait, the Rust compiler takes care of it.
```rust
impl<T: Trace + ?Sized> Deref for Gc<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &T {
        &self.inner().value()
    }
}
```
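The same mechanism can be reproduced with a tiny smart pointer: once Deref is implemented, field access goes through the wrapper transparently. MyGc below is my own stand-in that owns its value directly rather than pointing into a GC heap:

```rust
use std::ops::Deref;

struct Foo {
    a: u8,
}

// Minimal stand-in for Gc<T>: no GC heap, just a wrapped value,
// but field access derefs the same way.
struct MyGc<T> {
    value: T,
}

impl<T> Deref for MyGc<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.value
    }
}

fn main() {
    let z = MyGc { value: Foo { a: 7 } };
    // `z.a` works because the compiler rewrites it to `(*z).a` via Deref.
    assert_eq!(z.a, 7);
}
```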
The definition of the struct Gc<T> is as follows. It has only two fields.
```rust
pub struct Gc<T: Trace + ?Sized + 'static> {
    ptr_root: Cell<NonNull<GcBox<T>>>,
    marker: PhantomData<Rc<T>>,
}
```
Trace Trait
Per Designing a GC in Rust - In Pursuit of Laziness, trace is a way of walking the fields of a given object and finding the inner Gc<T> fields. For example, given one Gc<T> object, tracing should find all of its Gc<T> fields, including nested ones.
Let's look into the Trace trait. I'm not sure what root and unroot are doing yet; we'll get back to them later.
```rust
/// The Trace trait, which needs to be implemented on garbage-collected objects.
pub unsafe trait Trace: Finalize {
    /// Marks all contained `Gc`s.
    unsafe fn trace(&self);
    /// Increments the root-count of all contained `Gc`s.
    unsafe fn root(&self);
    /// Decrements the root-count of all contained `Gc`s.
    unsafe fn unroot(&self);
    /// Runs Finalize::finalize() on this object and all
    /// contained subobjects
    fn finalize_glue(&self);
}
```
Here is one Trace implementation, for 'static references.
```rust
unsafe impl<T: ?Sized> Trace for &'static T {
    unsafe_empty_trace!();
}
```
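My reading of what unsafe_empty_trace! expands to, written out against a pared-down copy of the trait so it compiles standalone (the real macro also wires up finalize_glue, omitted here):

```rust
// Pared-down copy of rust-gc's Trace trait (supertraits omitted).
unsafe trait Trace {
    unsafe fn trace(&self);
    unsafe fn root(&self);
    unsafe fn unroot(&self);
}

// What unsafe_empty_trace! amounts to: every method is a no-op,
// because a &'static T can never own a Gc that needs tracking.
unsafe impl<T: ?Sized> Trace for &'static T {
    unsafe fn trace(&self) {}
    unsafe fn root(&self) {}
    unsafe fn unroot(&self) {}
}

fn main() {
    let r: &'static str = "hello";
    unsafe {
        r.trace(); // no-op: nothing to mark
        r.root();  // no-op: no root counts to bump
        r.unroot();
    }
}
```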
unsafe_empty_trace! is a macro that provides no-op trace, root, and unroot methods. This makes sense because the 'static lifetime indicates that the data pointed to by the reference lives for the entire lifetime of the running program, so we don't even need to track or sweep it. Let's look at one more example: the trace for an array. I would guess it marks all elements in the array. Let's confirm.
```rust
unsafe impl<T: Trace, const N: usize> Trace for [T; N] {
    custom_trace!(this, {
        for v in this {
            mark(v);
        }
    });
}
```
custom_trace! is implemented as follows. It defines an inline mark function and calls it in $body.
```rust
/// This rule implements the trace method.
///
/// You define a `this` parameter name and pass in a body, which should call `mark` on every
/// traceable element inside the body. The mark implementation will automatically delegate to the
/// correct method on the argument.
#[macro_export]
macro_rules! custom_trace {
    ($this:ident, $body:expr) => {
        #[inline]
        unsafe fn trace(&self) {
            #[inline]
            unsafe fn mark<T: $crate::Trace + ?Sized>(it: &T) {
                $crate::Trace::trace(it);
            }
            let $this = self;
            $body
        }
```
I think $crate::Trace::trace(it); calls trace() on Gc<T>, but I'm not 100% sure.
```rust
unsafe impl<T: Trace + ?Sized> Trace for Gc<T> {
    #[inline]
    unsafe fn trace(&self) {
        self.inner().trace_inner();
    }
```
Then it calls trace_inner() on the GcBox.
```rust
/// Marks this `GcBox` and marks through its data.
pub(crate) unsafe fn trace_inner(&self) {
    let marked = self.header.marked.get();
    if !marked {
        self.header.marked.set(true);
        self.data.trace();
    }
}
```
Remember that GcBox is the raw pointer allocated in Gc::new and stored in ptr_root.
```rust
pub fn new(value: T) -> Self {
    assert!(mem::align_of::<GcBox<T>>() > 1);

    unsafe {
        // Allocate the memory for the object
        let ptr = GcBox::new(value);

        // When we create a Gc<T>, all pointers which have been moved to the
        // heap no longer need to be rooted, so we unroot them.
        (*ptr.as_ptr()).value().unroot();
        let gc = Gc {
            ptr_root: Cell::new(NonNull::new_unchecked(ptr.as_ptr())),
            marker: PhantomData,
        };
```
Let's quickly look at the GcBox definition.
```rust
pub(crate) struct GcBoxHeader {
    // XXX This is horribly space inefficient - not sure if we care
    // We are using a word word bool - there is a full 63 bits of unused data :(
    // XXX: Should be able to store marked in the high bit of roots?
    roots: Cell<usize>,
    next: Option<NonNull<GcBox<dyn Trace>>>,
    marked: Cell<bool>,
}

#[repr(C)] // to justify the layout computation in Gc::from_raw
pub(crate) struct GcBox<T: Trace + ?Sized + 'static> {
    header: GcBoxHeader,
    data: T,
}
```
Okay. Now I have a better understanding. The trace method just sets marked=true, which means the pointer is in use. By the way, who calls the trace method, and when? The answer is collect_garbage => GcBox::trace_inner() => trace. This is exciting! We're so close to the core of the gc. Let's look at collect_garbage.
```rust
/// Collects garbage.
fn collect_garbage(st: &mut GcState) {
    __snip__
    unsafe fn mark(head: &mut Option<NonNull<GcBox<dyn Trace>>>) -> Vec<Unmarked> {
        // Walk the tree, tracing and marking the nodes
        let mut mark_head = *head;
        while let Some(node) = mark_head {
            __snip__
    unsafe {
        let unmarked = mark(&mut st.boxes_start);
        if unmarked.is_empty() {
            return;
        }
```
Now we know GcState.boxes_start is the actual starting point of mark, which recursively marks Gc<T> objects. My next question: who sets up boxes_start? The answer is GcBox::new. Whenever it allocates a new GcBox, it links the box into a linked list of GcBoxes and sets it as boxes_start.
```rust
let gcbox = Box::into_raw(Box::new(GcBox {
    header: GcBoxHeader {
        roots: Cell::new(1),
        marked: Cell::new(false),
        next: st.boxes_start.take(),
    },
    data: value,
}));

st.boxes_start = Some(unsafe { NonNull::new_unchecked(gcbox) });
```
And finally, GcState is a thread-local object, as follows. In other words, GcState is a kind of global variable that is local to each thread.
```rust
// The garbage collector's internal state.
thread_local!(static GC_STATE: RefCell<GcState> = RefCell::new(GcState {
    bytes_allocated: 0,
    threshold: INITIAL_THRESHOLD,
    boxes_start: None,
}));
```
Summary of Trace (Mark)
- GC-ed objects should implement the Trace trait.
- The trace method should recursively call trace on inner objects.
- In the mark phase, the gc calls trace on objects allocated by GcBox::new.
Next: Finalize, Sweep, root and unroot - Understanding Rust GC - higepon blog.
A week of debugging the Mosh compiler/VM
Mosh is becoming R7RS-compliant, so I ran the ecraven/r7rs-benchmarks suite against it. This time I ignored speed and focused on whether Mosh executes code correctly as an R7RS Scheme implementation. It turned out that 3 of the 57 benchmarks failed to run. Two of them were relatively easy to fix. The last one was tough, so I'm leaving a record of it.
The start
The failing benchmark is conform. Mosh returns a result different from the expected one, and the run time is suspiciously short. Looking at the benchmark script (r7rs-benchmarks/conform.scm), it seems to build a graph structure and do something with it. Honestly, I gave up on trying to understand the code after three seconds.
At this point I had no idea the debugging would get rough, so I added debug prints to find out why the result differed from the expectation, and I quickly got stuck. Because:
- A) It deals with graph structures that reference themselves, which doesn't play well with debug prints. write/ss can print shared structures, but the output is still hard to read.
- B) The data structures are too large to print in full.
- C) I couldn't tell what the code was doing in the first place.
I wasted 2-3 hours in this state.
Psychological safety
I finally realized this was going to be a long fight, so I decided to calm down a little. The problem was that I didn't know what the correct state looked like, so I ran the benchmark on Gauche and compared the two. Next, accepting that I didn't understand the logic, I identified the main routine more or less from its name. At this point I could finally put debug prints into the main routine and compare with Gauche, and I found a function that instantly returned a wrong value.
Once I knew where things went wrong, I was sure of victory. I decided to shrink the reproduction code while preserving the input to that function. This turned out to be quite hard: both the input and the output are graphs, which is a different kind of difficulty from dealing with strings and numbers. While trying various things the code became a mess, I couldn't get back to a working state, and I deeply regretted it. After creating a git branch called debug and moving in small steps, I suddenly made progress and stopped breaking things. Psychological safety matters: with checkpoints you can break things without fear. It took a full day to boil the problem down to a small piece of code.
```scheme
(define (foo)
  (define (bar n)
    (cond ((= n 1) #t)
          (else
           (let loop ((lst '(1)))
             (if (null? lst)
                 #t
                 (and (display "=> recusive1\n")
                      (bar 1)
                      (display "=> loop\n")
                      (loop (cdr lst))))))))
  (bar 0))
(foo)
```
It turned out that (display "=> loop\n") is not called after the call to (bar 1) in this code. This is clearly wrong, because (bar 1) returns #t.
So who is to blame?
Anyone who has written a Scheme implementation will see that everything here is suspicious: a define inside a define, a named let, tail recursion. The bug could be anywhere.
At this point, the candidate causes, in execution order, were:
- the S-expression transformations before the compiler, e.g. converting and into if
- the compiler's S-expression => iform conversion
- the compiler's optimizations on iform
- the compiler's code generation
- the compiler's merge-instructions pass
- the VM
- the R6RS -> backend conversion
- the R7RS -> R6RS conversion
and I had to keep switching between them. On top of that, it had been more than ten years since I last touched this area!
Investigating the compiler and VM
The bug no longer reproduced with optimizations turned OFF, so the likely suspects were:
- the optimization itself being wrong, or
- a path that fails to correctly execute the code the optimization produces.
Here is the iform in the middle of the pass2 optimization. Looking at it now, with eyes that have been through a week of debugging, the problem is obvious.
($if ($asm NUMBER_EQUAL ($lref n[1;0]) ($const 1)) ($const #t) ($letrec ((loop[0;0] ($lambda[loop;2] (lst[2;0]) ($label #0 ($if ($asm NULL_P ($lref lst[2;0])) ($const #t) ($if ($call ($gref display) ($const "=> recusive1\n")) ($if ($call[tail-rec] ($lref bar[2;0]) ($const 1)) ($if ($call ($gref display) ($const "=> loop\n")) ($call[jump] ($call[embed] ($lambda[loop;2] (lst[2;0]) label#0) ($const (1))) ($asm CDR ($lref lst[2;0]))) ($it)) ($it)) ($it)))))) ) ($call[embed] ($lambda[loop;2] (lst[2;0]) label#0) ($const (1)))))
The iform after the pass2 optimization is as follows. I spent way too much time here.
($define () foo ($lambda[foo;0] () ($call[embed 4 0] ($lambda[bar;2] (n[1;0]) ($label #0 ($if ($asm NUMBER_EQUAL ($lref n[1;0]) ($const 1)) ($const #t) ($call[embed 7 0] ($lambda[loop;2] (lst[2;0]) ($label #1 ($if ($asm NULL_P ($lref lst[2;0])) ($const #t) ($if ($call ($gref display) ($const "=> recusive1\n")) ($if ($call[jump 0 0] ($call[embed 4 0] ($lambda[bar;2] (n[1;0]) label#0) ($const 0)) ($const 1)) ($if ($call ($gref display) ($const "=> loop\n")) ($call[jump 0 0] ($call[embed 7 0] ($lambda[loop;2] (lst[2;0]) label#1) ($const (1))) ($asm CDR ($lref lst[2;0]))) ($it)) ($it)) ($it))))) ($const (1)))))) ($const 0))))
In the end I could tell that something was wrong around the part of the pass2 optimization that embeds local calls, but I couldn't tell whether this iform was wrong or the downstream processing was wrong, so I looked downstream too. Printing the actual VM instruction sequence only confused me more.
CLOSURE 81 0 #f 0 11 ((reproduce.scm 1) foo) LET_FRAME 7 CONSTANT_PUSH 0 ENTER ;; Label #0 1 REFER_LOCAL_PUSH_CONSTANT 0 1 BRANCH_NOT_NUMBER_EQUAL ;; if (= n 1) 5 CONSTANT #t LOCAL_JMP ;; goto label #1 57 LET_FRAME 5 REFER_LOCAL_PUSH 0 DISPLAY 1 CONSTANT_PUSH (1) ENTER 1 REFER_LOCAL_BRANCH_NOT_NULL 0 5 CONSTANT #t LOCAL_JMP 38 FRAME 6 CONSTANT_PUSH => recusive1 REFER_GLOBAL_CALL display ;; (display "=> recusive1\n") 1 ;; # of args TEST ;; (display ...) return value is true. So we skip the +1 next line and go to +2 line. 29 CONSTANT_PUSH ;; Come here after (display ...) call 1 SHIFTJ ;; adjust SP and FP 1 ;; depth 4 ;; diff 0 ;; How many closure to go up? LOCAL_JMP ;; Jump to label #0 -42 TEST 19 FRAME 6 CONSTANT_PUSH => loop REFER_GLOBAL_CALL display 1 TEST 10 REFER_LOCAL 0 CDR_PUSH SHIFTJ 1 1 0 LOCAL_JMP -43 LEAVE 1 LEAVE ;; label #2 1 ;; adjust stack RETURN ;; return to the code (the code is taken from the top of stack). ** But we don't have the code in the stack?*** 0 DEFINE_GLOBAL foo HALT NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP NOP)
Here is the dynamic execution in the VM, together with the stack.
======================== FRAME FP|0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" ======================== REFER_GLOBAL FP|0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" ======================== CALL |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" ======================== LET_FRAME # LET_FRAME for lambda[foo] |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" FP|4: 4 # Save fp |5: "#(#(CLOSURE ...))" # Closure ======================== CONSTANT |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" FP|4: 4 |5: "#(#(CLOSURE ...))" ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" FP|4: 4 |5: "#(#(CLOSURE ...))" |6: 0 # Constant 0 ======================== ENTER # Adjust fp |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 ======================== REFER_LOCAL # a=(Constant 0) |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 # Constant 0 |7: 0 # Constant 0 ======================== CONSTANT # a=(Constant 1) |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 |7: 0 ======================== BRANCH_NOT_NUMBER_EQUAL # a != stack-bottom(Constant 0) |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 # Discarded stack top. ======================== LET_FRAME # LET_FRAME for loop. |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 |7: 6 # Save fp |8: "#(#(CLOSURE ...))" # Push closure ======================== REFER_LOCAL # a=(Constant 0) REALLY??? 
|0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 |7: 6 |8: "#(#(CLOSURE ...))" ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 0 # push (Constant 0) ======================== DISPLAY # Create a display and set it to closure. |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 |7: 6 |8: "#(#(CLOSURE ...))"# Note stack is popped. ======================== CONSTANT # a=(1) |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 |7: 6 |8: "#(#(CLOSURE ...))" ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 # (1) ======================== ENTER |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 # New FP. ======================== REFER_LOCAL # a=(1) |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 ======================== BRANCH_NOT_NULL # Go to else. |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 ======================== FRAME # Push stuff. 
|0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 |10: "(#(#f ...)" # closure |11: 9 # fp |12: 54 # pc + n |13: "(#(CLOSURE ...)" # Current codes ======================== CONSTANT # a="=> recursive1" |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" |14: "=> recusive1\n" ======================== REFER_GLOBAL # a=<display> |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" |14: "=> recusive1\n" ======================== CALL # call a=<display>. Note codes is now body of display. |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" FP|14: "=> recusive1\n" |15: () # display has optional-arg '() ======================== REFER_LOCAL # a="=> recursive1\n" |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" FP|14: "=> recusive1\n" |15: () ======================== BRANCH_NOT_NULL # a is not null so we go to Else. But this is really weird. 
|0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" FP|14: "=> recusive1\n" |15: () ======================== REFER_LOCAL # a="=> recursive1" |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" FP|14: "=> recusive1\n" |15: () ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" FP|14: "=> recusive1\n" |15: () |16: "=> recusive1\n" ======================== REFER_FREE 0 # Point codes + 6 |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" FP|14: "=> recusive1\n" |15: () |16: "=> recusive1\n" ======================== TAIL_CALL # call <display> and jump to #(RETURN 1 ...) 
=> recusive1 |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" |9: 1 |10: "(#(#f ...)" |11: 9 |12: 54 |13: "(#(CLOSURE ...)" FP|14: "=> recusive1\n" ======================== RETURN |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 # this is () ======================== TEST # Return value of display is #t |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 ======================== CONSTANT |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" |6: 0 |7: 6 |8: "#(#(CLOSURE ...))" FP|9: 1 # (1) |10: 1 # 1 ======================== SHIFTJ 1 4 0 # Adjust frame for jump. Stack is now for bar call. |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 ======================== LOCAL_JMP |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 ======================== REFER_LOCAL # a=1 |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 ======================== PUSH |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 |7: 1 ======================== CONSTANT |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 |7: 1 ======================== BRANCH_NOT_NUMBER_EQUAL # Now 1=1. Jump to else. 
|0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 ======================== CONSTANT #t |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 ======================== LOCAL_JMP 66 |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" |4: 4 |5: "#(#(CLOSURE ...))" FP|6: 1 ======================== LEAVE |0: "#(#(FRAME ...))" |1: 0 |2: 6 |3: "(#(FRAME ...)" ======================== RETURN ======================== HALT
Replaying this in my head many times and going back and forth with the iform, I realized it is self-evident that after executing (bar 1), a local_jump has no way to come back. This lambda embedding is wrong. Once I noticed that, the rest was easy: the call got embedded in the middle of the (and ...) because the (bar 1) call was misidentified as tail-recursive.
This one-line fix finished the job.
MNIST in Scheme, continued
A continuation of Lessons from MNIST in Scheme - higepon blog.
f64array
Based on the previous prototype, I introduced f64array as a Mosh system library. It represents a two-dimensional matrix of doubles. The procedures added are the bare minimum:
- make-f64array
- f64array-ref
- f64array-set!
- f64array-shape
- f64array-dot-product
These cover construction, setter/getter, and the dot product. Looking at the profiler, the procedure that returns a matrix's shape is also called a non-negligible number of times, probably from broadcasting in matrix-to-matrix operations.
Writing portable code
R7RS has a mechanism for handling per-implementation differences nicely, so let's arrange the code to run on both Gauche and Mosh. Matrix initialization uses make-normal-generator, but Gauche doesn't seem to have SRFI-194, so I bring in a similar procedure from the data.random library and rename it. Incidentally, the author of SRFI-194 is Shiro-san.
```scheme
(cond-expand
  [gauche
   (import (rename (data random) (reals-normal$ make-normal-generator)))]
  [(library (srfi 194))
   (import (only (srfi 194) make-normal-generator))])
```
Also, Mosh preferentially loads libraries with the .mosh.sld extension, so a library can be split into one that Mosh loads and one that other implementations load. This is where the matrix implementation diverges.
Code common to both matrix libraries is included directly, as in (include "matrix-common.scm").
And so we happily have an MNIST demo that runs on both Mosh and Gauche. I expect it would run on other R7RS implementations with small adjustments. PRs welcome. The code is at mosh/tests/mnist at master · higepon/mosh · GitHub.
Lessons from MNIST in Scheme
I'm touching Mosh Scheme again after a long time. Taking it as a good opportunity, I decided to implement an MNIST demo in Scheme, using ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装 (Deep Learning from Scratch) as a reference. MNIST is the demo that recognizes handwritten digits from 0 to 9, the one you often see in neural network framework demos and tutorials. Since Scheme has nothing like Python's numpy, I implement all the necessary matrix operations in Scheme.
Data sizes
- train data: images (28 x 28 x 1) with labels (0-9)
- train data: 60000 samples
- test data: 10000 samples
Expectations and predictions before implementing
- Other languages and frameworks can run this demo, so it should run fast enough on a CPU.
- The model and training data are relatively small, so a pure Scheme implementation should be fine.
- matrix-mul may become the bottleneck.
Technical choices
- Represent a matrix as a vector of vectors. A uniform vector (one whose element type is fixed) would be an option, but Mosh doesn't implement them, so I passed.
- Don't aim for fast matrix operations; a naive implementation is fine.
Implementing matrix-mul
A naive implementation with no particular tricks: a triple loop.
```scheme
(define (matrix-mul a b)
  (unless (= (matrix-shape a 1) (matrix-shape b 0))
    (error "matrix-mul shapes don't match" (matrix-shape a) (matrix-shape b)))
  (let* ([nrows (matrix-shape a 0)]
         [ncols (matrix-shape b 1)]
         [m     (matrix-shape a 1)]
         [mat   (matrix nrows ncols)])
    (define (mul row col)
      (let loop ([k 0] [ret 0])
        (if (= k m)
            ret
            (loop (+ k 1) (+ ret (* (mat-at a row k) (mat-at b k col)))))))
    (do ((i 0 (+ i 1)))
        ((= i nrows) mat)
      (do ((j 0 (+ j 1)))
          ((= j ncols))
        (mat-at mat i j (mul i j))))))
```
Result 1
I ran a two-layer NN with hidden_size=10 and batch_size=3. Here is the profiler result after one batch: matrix-mul was called 32165 times and accounts for 95% (98 seconds) of the time.
```
Time%  msec   calls  name                 location
95     98500  32165  (matrix-mul a b)     <transcoded-textual-input-port <binary-input-port tests/mnist.scm>>:175
1      1140   64344  (matrix-map proc a)  <transcoded-textual-input-port <binary-input-port tests/mnist.scm>>:96
```
Result 2
Result 1 is clearly slow, so I reimplemented it on the C++ side: 98 seconds -> 75 seconds, not as fast as I expected. Conversely, the VM loop is faster than I thought.
```
92  74590  32165  matrix-mul
1   680    64543  (matrix-element-wise op a...)  <transcoded-textual-input-port <binary-input-port tests/mnist.scm>>:186
```
Here we need to look carefully at the C++ implementation. It keeps intermediate results in Scheme Objects, and since a matrix may hold either Fixnums or Flonums, it computes with Arithmetic::add and Arithmetic::mul, which perform type checks. As a result, heap allocation happens in the innermost loop.
```cpp
Object ret = Object::makeVector(numRowsA);
for (size_t i = 0; i < numRowsA; i++) {
    ret.toVector()->set(i, Object::makeVector(numColsB));
}
for (size_t i = 0; i < numRowsA; i++) {
    for (size_t j = 0; j < numColsB; j++) {
        Object value = Object::makeFlonum(0.0);
        for (size_t k = 0; k < numColsA; k++) {
            Object aik = a->ref(i).toVector()->ref(k);
            Object bkj = b->ref(k).toVector()->ref(j);
            value = Arithmetic::add(value, Arithmetic::mul(aik, bkj));
        }
        ret.toVector()->ref(i).toVector()->set(j, value);
    }
}
```
Result 3
With a uniform vector, the internal type checks would be unnecessary and all computation could stay in double. I tried a rough implementation that mimics this. The interpreter will crash if fed invalid data, but for a demo that's fine. 98 seconds dropped to 60 milliseconds, a practical speed. A result that makes you want to avoid heap allocation.
```
1  60  32165  matrix-mul
```
```cpp
double value = 0.0;
for (size_t k = 0; k < numColsA; k++) {
    Object aa = a->ref(i).toVector()->ref(k);
    Object bb = b->ref(k).toVector()->ref(j);
    double aik = aa.isFixnum() ? aa.toFixnum() : aa.toFlonum()->value();
    double bkj = bb.isFixnum() ? bb.toFixnum() : bb.toFlonum()->value();
    value += aik * bkj;
}
ret.toVec
```
Lessons and ideas for improvement
- Representing a matrix as vectors incurs type-check overhead.
- The internal representation of a matrix should be a 2D array of double/int, with accessors and mul implemented in C++.
- Matrix operations other than matrix-mul can stay in Scheme with little problem; even if somewhat slow, they are called far less often than matrix-mul, e.g. add, element-wise multiply, transpose, shape.
- matrix-mul could get much faster with SIMD and the like.
Asides
- numpy is well designed, as you would expect from its long history.
- Object-oriented notation is sometimes easier to read: a.t is the transpose of a, while (t a) reads slightly differently and t is too broad a name.
- numpy's broadcasting mechanism is convenient and, I now see, an essential building block, e.g. multiplying a scalar lr by a matrix A.
- It would be interesting if a subset of numpy were proposed as a SRFI. But as noted above, a pure Scheme implementation would be slow, which makes it a dilemma; it seems unrealistic to expect each implementation to roll its own.
- Realistically, it might be enough if Scheme could be used as a Tensorflow frontend, but then Python is good enough.
- The path to casually enjoying machine learning in Scheme is surprisingly long. I now understand why Python is the sole champion.