Rustで深層学習フレームワークを開発しています
リポジトリ
⭐️がつくとやる気が出ます!
サンプルコード(GAN など)
ちなみに上記のサンプルを実行すると、以下のような画像が生成できました。
はじめに
- Rust で深層学習フレームワーク Zenu を開発しています。
- CPU / GPU 両方に対応し、型安全とメモリ安全を重視した設計です。
- MNIST や GAN などのサンプルコードがありますので、ぜひ試してみてください。
モチベーション
Python / PyTorch での開発のつらさ
深層学習といえば Python + PyTorch が定番ですが、実際の研究・開発では以下のような問題に悩まされることが多いです。
-
静的型がない(動的型付け)
GPU / CPU 間のデバイス不一致や、float32 と float64 の混在など、型にまつわるバグがランタイムエラーで起こりがち。長時間学習してやっと失敗がわかることも…。 -
ランタイムエラーが学習後に発覚しがち
大規模モデルや長時間学習を回してから「デバイスが違う」などのエラーに直面すると、とてもつらい。再学習のコストが大きい。学習が一通り終わって最後のtestを回し始めた時にruntime errorが出ると、とても辛い。
import torch
x = torch.ones((2, 2), device="cuda")
y = torch.ones((2, 2), device="cpu")
# GPU テンソルと CPU テンソルを足す → 実行時エラー
z = x + y
実行すると…
RuntimeError: Expected all tensors to be on the same device...
学習や推論の途中で気づくと精神的なダメージが大きいですよね。
Rust でやるメリット
-
型安全 & 所有権
Rust は静的型で、かつ型推論が優秀です。テンソルのデバイス (CPU / GPU) やスカラー型 (f32 / f64) を型レベルで区別でき、ミスマッチはコンパイルエラーに。
例:Matrix<Owned<f32>, Dim2, Cpu>
とMatrix<Owned<f32>, Dim2, Nvidia>
は別の型として扱うので、CPU テンソルに GPU テンソルを足すようなコードはコンパイルが通りません。 -
Cargo によるテストが標準で充実
C++ と比べて、ユニットテストやベンチマークを導入しやすい。CMake などの外部ツールをあまり使わなくて済む。
この利点はかなり大きいと思っています。 -
Rust が好き
皆さんはRust好きですよね?ね?ね?じゃあ、Rust で深層学習フレームワークを作りましょう!
Zenu の概要
Zenu は、大きく分けると以下の 6 クレートに分割して開発しています。
-
zenu
- トップレベルクレート。利用するときはこれを指定。
-
feature = ["nvidia"]
を有効にすると CUDA / cuBLAS / cuDNN を使った GPU 演算が可能。
-
zenu-cuda
- CUDA runtime / driver / cuBLAS / cuDNN など、NVIDIA 系の低レベル API をまとめたクレート。
- CUDA カーネルのビルドや呼び出し周りをカプセル化。
-
zenu-matrix
- pythonでいうところの numpy に相当する多次元配列クレート。
- ndarray を参考にさせていただきました。(圧倒的感謝)
- CPU / GPU 両方に対応。
-
zenu-autograd
- 自動微分&演算グラフクレート。
- PyTorch の
torch.autograd
みたいなイメージで、forward / backward を実装。 -
Variable
型を用いて勾配を保持。
-
zenu-layers
- PyTorch の
torch.nn
相当の高レベルレイヤ (Linear, Conv, BatchNorm など) をまとめるクレート。 -
Module
トレイトによるインターフェースを提供。
- PyTorch の
-
zenu-optimizer
- パラメータ更新アルゴリズム (SGD, Adam, AdamW など) をまとめるクレート。
- PyTorch の
torch.optim
に近いイメージ。
クレート分割の課題
- 細かく分けすぎて 共通の変更が複数クレートに及ぶ と、API の食い違いが起きやすい。
- 自動微分はテンソル演算と密接なので、切り離すのが意外と大変だった。
feature="nvidia"
で GPU サポートを ON にする
Cargo.toml
の dependencies に次のように書きます:
[dependencies.zenu]
version = "*"
features = ["nvidia"] # GPU 機能を有効化
これで内部的に zenu-cuda
がビルドされ、Nvidia
デバイスが使えるようになります。
use zenu::matrix::device::nvidia::Nvidia;
let model_gpu = SimpleModel::<f32, Nvidia>::new();
もし CPU から GPU へ転送するなら
let model_gpu = model.to::<Nvidia>();
と書くだけです。
主な機能
- 最適化アルゴリズム: SGD / Adam / AdamW など
- NN レイヤー: Linear, Convolution, Pooling, Dropout, BatchNorm など
- 活性化関数: ReLU, Sigmoid, Tanh, Softmax
- 損失関数: 二乗誤差、クロスエントロピー
MNIST を使った実装例
ここでは Zenu を使って MNIST で学習する簡単なサンプルコードを載せます。
MNISTコード例
use zenu::{
autograd::{
activation::relu::relu, creator::from_vec::from_vec, loss::cross_entropy::cross_entropy,
no_train, set_train, Variable,
},
dataset::{train_val_split, DataLoader, Dataset},
dataset_loader::mnist_dataset,
layer::{layers::linear::Linear, Module},
matrix::{
device::{cpu::Cpu, Device},
num::Num,
},
optimizer::{sgd::SGD, Optimizer},
};
use zenu_macros::Parameters;
// モデル定義
#[derive(Parameters)]
#[parameters(num=T, device=D)]
pub struct SimpleModel<T: Num, D: Device> {
pub linear_1: Linear<T, D>,
pub linear_2: Linear<T, D>,
}
impl<D: Device> SimpleModel<f32, D> {
#[must_use]
pub fn new() -> Self {
Self {
linear_1: Linear::new(28 * 28, 512, true),
linear_2: Linear::new(512, 10, true),
}
}
}
impl<D: Device> Default for SimpleModel<f32, D> {
fn default() -> Self {
Self::new()
}
}
// Module トレイト実装 (forward 計算)
impl<D: Device> Module<f32, D> for SimpleModel<f32, D> {
type Input = Variable<f32, D>;
type Output = Variable<f32, D>;
fn call(&self, inputs: Variable<f32, D>) -> Variable<f32, D> {
let x = self.linear_1.call(inputs);
let x = relu(x);
self.linear_2.call(x)
}
}
// MNIST データセット
struct MnistDataset {
data: Vec<(Vec<u8>, u8)>,
}
impl Dataset<f32> for MnistDataset {
type Item = (Vec<u8>, u8);
fn item(&self, item: usize) -> Vec<Variable<f32, Cpu>> {
let (x, y) = &self.data[item];
let x_f32 = x.iter().map(|&xi| xi as f32).collect::<Vec<_>>();
let x = from_vec::<f32, _, Cpu>(x_f32, [784]);
x.get_data_mut().to_ref_mut().div_scalar_assign(127.5);
x.get_data_mut().to_ref_mut().sub_scalar_assign(1.0);
let y_onehot = (0..10)
.map(|i| if i == *y as usize { 1.0 } else { 0.0 })
.collect::<Vec<_>>();
let y = from_vec(y_onehot, [10]);
vec![x, y]
}
fn len(&self) -> usize {
self.data.len()
}
fn all_data(&mut self) -> &mut [Self::Item] {
&mut self.data
}
}
#[expect(clippy::cast_precision_loss)]
fn main() {
// モデルを CPU デバイスで作成 (GPU の場合は SimpleModel::<f32, Nvidia>::new())
let model = SimpleModel::<f32, Cpu>::new();
// MNIST データ読み込み
let (train, test) = mnist_dataset().unwrap();
let (train, val) = train_val_split(&train, 0.8, true);
// DataLoader の作成 (PyTorch でいう DataLoader に近い)
let test_dataloader = DataLoader::new(MnistDataset { data: test }, 1);
// 最適化アルゴリズム: SGD
let optimizer = SGD::<f32, Cpu>::new(0.01);
for epoch in 0..20 {
set_train(); // PyTorch の with torch.no_grad() の“逆”
let mut train_dataloader = DataLoader::new(
MnistDataset {
data: train.clone(),
},
32,
);
train_dataloader.shuffle();
let mut train_loss = 0.0;
let mut num_iter = 0;
for batch in train_dataloader {
let input = batch[0].clone();
let target = batch[1].clone();
let pred = model.call(input);
let loss = cross_entropy(pred, target);
let loss_asum = loss.get_data().asum();
// バックワードとパラメータ更新
loss.backward();
optimizer.update(&model);
loss.clear_grad();
train_loss += loss_asum;
num_iter += 1;
}
train_loss /= num_iter as f32;
// バリデーション
no_train(); // PyTorch の with torch.no_grad()
let val_loader = DataLoader::new(MnistDataset { data: val.clone() }, 1);
let mut val_loss = 0.0;
let mut num_val_iter = 0;
for batch in val_loader {
let input = batch[0].clone();
let target = batch[1].clone();
let pred = model.call(input);
let loss = cross_entropy(pred, target);
val_loss += loss.get_data().asum();
num_val_iter += 1;
}
val_loss /= num_val_iter as f32;
println!("Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}");
}
// テスト
let mut test_loss = 0.0;
let mut num_test_iter = 0;
for batch in test_dataloader {
let input = batch[0].clone();
let target = batch[1].clone();
let pred = model.call(input);
let loss = cross_entropy(pred, target);
test_loss += loss.get_data().asum();
num_test_iter += 1;
}
println!("Test Loss: {}", test_loss / num_test_iter as f32);
}
GPU で走らせたい場合は、feature で nvidia を有効にして、Variableやモデルに対して.to::<Nvidia>()
を呼び出すだけです。
他の Rust 製フレームワークとの比較
-
candle
Rust で自前実装を頑張っているフレームワーク。Zenu と方向性は似ていますが、Metal や FlashAttention など先進的なバックエンドもサポートしていてすごい。 -
tch-rs / burn
- tch-rs は PyTorch C++ バインディング。中身は PyTorch と同じなので実績は十分だけど、Rust の所有権や型システムを最大限活かすわけではない。
- burn は独自 DSL を構築するアプローチが印象的。Zenu とは別のベクトルで面白い。
今後の展望
-
SIMD
CPU での SSE / AVX などを利用した高速化 - マルチ GPU / 分散学習
- Transformer 系モデルの実装
- PyTorch モデルの読み書き / ONNX 対応
-
ドキュメントの充実
- チュートリアルや API リファレンスを拡充
-
サンプルコードの拡充
- 画像分類、物体検出、セグメンテーションなど
- 音声系の会社に転職したので音声系も頑張りたい
-
エラーハンドリングの改善
- よりわかりやすいエラーメッセージを目指す
まとめ
- Rust 製の深層学習フレームワーク Zenu を作っています。
- CPU / GPU 両方に対応し、型安全とメモリ安全を重視した設計です。
- MNIST や GAN などのサンプルコードがありますので、ぜひ試してみてください。
- Issues、PR、マサカリ、大歓迎です!スターをいただけるとモチベが爆上がりします。
参考リンク
-
ndarray
- zenu-matrix の開発にあたり、多次元配列の実装でお世話になりました。
-
ゼロから作る Deep Learning 3
- 自動微分の実装で参考にしました。
-
【作品紹介】Common Lispで深層学習フレームワークを0から作ってる話
- 同じく「自作フレームワークをやってみるぞ!」という気持ちになったきっかけの記事です。
Discussion