🦀

Rustで深層学習フレームワークを開発しています

2024/12/23に公開

リポジトリ

⭐️がつくとやる気が出ます!
https://github.com/bokutotu/zenu

サンプルコード(GAN など)

https://github.com/bokutotu/zenu-examples
ちなみに上記のサンプルを実行すると、以下のような画像が生成できました。

GAN Generated Image GAN Generated Image GAN Generated Image GAN Generated Image

はじめに

  • Rust で深層学習フレームワーク Zenu を開発しています。
  • CPU / GPU 両方に対応し、型安全とメモリ安全を重視した設計です。
  • MNIST や GAN などのサンプルコードがありますので、ぜひ試してみてください。

モチベーション

Python / PyTorch での開発のつらさ

深層学習といえば Python + PyTorch が定番ですが、実際の研究・開発では以下のような問題に悩まされることが多いです。

  • 静的型がない(動的型付け)
    GPU / CPU 間のデバイス不一致や、float32 と float64 の混在など、型にまつわるバグがランタイムエラーで起こりがち。長時間学習してやっと失敗がわかることも…。

  • ランタイムエラーが学習後に発覚しがち
    大規模モデルや長時間学習を回してから「デバイスが違う」などのエラーに直面すると、とてもつらい。再学習のコストが大きい。学習が一通り終わって最後のtestを回し始めた時にruntime errorが出ると、とても辛い。

import torch

x = torch.ones((2, 2), device="cuda")
y = torch.ones((2, 2), device="cpu")

# GPU テンソルと CPU テンソルを足す → 実行時エラー
z = x + y

実行すると…

RuntimeError: Expected all tensors to be on the same device...

学習や推論の途中で気づくと精神的なダメージが大きいですよね。

Rust でやるメリット

  • 型安全 & 所有権
    Rust は静的型で、かつ型推論が優秀です。テンソルのデバイス (CPU / GPU) やスカラー型 (f32 / f64) を型レベルで区別でき、ミスマッチはコンパイルエラーに。
    例: Matrix<Owned<f32>, Dim2, Cpu>Matrix<Owned<f32>, Dim2, Nvidia> は別の型として扱うので、CPU テンソルに GPU テンソルを足すようなコードはコンパイルが通りません。

  • Cargo によるテストが標準で充実
    C++ と比べて、ユニットテストやベンチマークを導入しやすい。CMake などの外部ツールをあまり使わなくて済む。
    この利点はかなり大きいと思っています。

  • Rust が好き
    皆さんはRust好きですよね?ね?ね?じゃあ、Rust で深層学習フレームワークを作りましょう!

Zenu の概要

Zenu は、大きく分けると以下の 6 クレートに分割して開発しています。

  1. zenu

    • トップレベルクレート。利用するときはこれを指定。
    • feature = ["nvidia"] を有効にすると CUDA / cuBLAS / cuDNN を使った GPU 演算が可能。
  2. zenu-cuda

    • CUDA runtime / driver / cuBLAS / cuDNN など、NVIDIA 系の低レベル API をまとめたクレート。
    • CUDA カーネルのビルドや呼び出し周りをカプセル化。
  3. zenu-matrix

    • pythonでいうところの numpy に相当する多次元配列クレート。
    • ndarray を参考にさせていただきました。(圧倒的感謝)
    • CPU / GPU 両方に対応。
  4. zenu-autograd

    • 自動微分&演算グラフクレート。
    • PyTorch の torch.autograd みたいなイメージで、forward / backward を実装。
    • Variable 型を用いて勾配を保持。
  5. zenu-layers

    • PyTorch の torch.nn 相当の高レベルレイヤ (Linear, Conv, BatchNorm など) をまとめるクレート。
    • Module トレイトによるインターフェースを提供。
  6. zenu-optimizer

    • パラメータ更新アルゴリズム (SGD, Adam, AdamW など) をまとめるクレート。
    • PyTorch の torch.optim に近いイメージ。

クレート分割の課題

  • 細かく分けすぎて 共通の変更が複数クレートに及ぶ と、API の食い違いが起きやすい。
  • 自動微分はテンソル演算と密接なので、切り離すのが意外と大変だった。

feature="nvidia" で GPU サポートを ON にする

Cargo.toml の dependencies に次のように書きます:

[dependencies.zenu]
version = "*"
features = ["nvidia"]  # GPU 機能を有効化

これで内部的に zenu-cuda がビルドされ、Nvidia デバイスが使えるようになります。

use zenu::matrix::device::nvidia::Nvidia;

let model_gpu = SimpleModel::<f32, Nvidia>::new();

もし CPU から GPU へ転送するなら

let model_gpu = model.to::<Nvidia>();

と書くだけです。


主な機能

  • 最適化アルゴリズム: SGD / Adam / AdamW など
  • NN レイヤー: Linear, Convolution, Pooling, Dropout, BatchNorm など
  • 活性化関数: ReLU, Sigmoid, Tanh, Softmax
  • 損失関数: 二乗誤差、クロスエントロピー

MNIST を使った実装例

ここでは Zenu を使って MNIST で学習する簡単なサンプルコードを載せます。

MNISTコード例
use zenu::{
    autograd::{
        activation::relu::relu, creator::from_vec::from_vec, loss::cross_entropy::cross_entropy,
        no_train, set_train, Variable,
    },
    dataset::{train_val_split, DataLoader, Dataset},
    dataset_loader::mnist_dataset,
    layer::{layers::linear::Linear, Module},
    matrix::{
        device::{cpu::Cpu, Device},
        num::Num,
    },
    optimizer::{sgd::SGD, Optimizer},
};
use zenu_macros::Parameters;

// モデル定義
#[derive(Parameters)]
#[parameters(num=T, device=D)]
pub struct SimpleModel<T: Num, D: Device> {
    pub linear_1: Linear<T, D>,
    pub linear_2: Linear<T, D>,
}

impl<D: Device> SimpleModel<f32, D> {
    #[must_use]
    pub fn new() -> Self {
        Self {
            linear_1: Linear::new(28 * 28, 512, true),
            linear_2: Linear::new(512, 10, true),
        }
    }
}

impl<D: Device> Default for SimpleModel<f32, D> {
    fn default() -> Self {
        Self::new()
    }
}

// Module トレイト実装 (forward 計算)
impl<D: Device> Module<f32, D> for SimpleModel<f32, D> {
    type Input = Variable<f32, D>;
    type Output = Variable<f32, D>;

    fn call(&self, inputs: Variable<f32, D>) -> Variable<f32, D> {
        let x = self.linear_1.call(inputs);
        let x = relu(x);
        self.linear_2.call(x)
    }
}

// MNIST データセット
struct MnistDataset {
    data: Vec<(Vec<u8>, u8)>,
}

impl Dataset<f32> for MnistDataset {
    type Item = (Vec<u8>, u8);

    fn item(&self, item: usize) -> Vec<Variable<f32, Cpu>> {
        let (x, y) = &self.data[item];
        let x_f32 = x.iter().map(|&xi| xi as f32).collect::<Vec<_>>();
        let x = from_vec::<f32, _, Cpu>(x_f32, [784]);
        x.get_data_mut().to_ref_mut().div_scalar_assign(127.5);
        x.get_data_mut().to_ref_mut().sub_scalar_assign(1.0);

        let y_onehot = (0..10)
            .map(|i| if i == *y as usize { 1.0 } else { 0.0 })
            .collect::<Vec<_>>();
        let y = from_vec(y_onehot, [10]);

        vec![x, y]
    }

    fn len(&self) -> usize {
        self.data.len()
    }

    fn all_data(&mut self) -> &mut [Self::Item] {
        &mut self.data
    }
}

#[expect(clippy::cast_precision_loss)]
fn main() {
    // モデルを CPU デバイスで作成 (GPU の場合は SimpleModel::<f32, Nvidia>::new())
    let model = SimpleModel::<f32, Cpu>::new();

    // MNIST データ読み込み
    let (train, test) = mnist_dataset().unwrap();
    let (train, val) = train_val_split(&train, 0.8, true);

    // DataLoader の作成 (PyTorch でいう DataLoader に近い)
    let test_dataloader = DataLoader::new(MnistDataset { data: test }, 1);

    // 最適化アルゴリズム: SGD
    let optimizer = SGD::<f32, Cpu>::new(0.01);

    for epoch in 0..20 {
        set_train();  // PyTorch の with torch.no_grad() の“逆”
        let mut train_dataloader = DataLoader::new(
            MnistDataset {
                data: train.clone(),
            },
            32,
        );
        train_dataloader.shuffle();

        let mut train_loss = 0.0;
        let mut num_iter = 0;

        for batch in train_dataloader {
            let input = batch[0].clone();
            let target = batch[1].clone();
            let pred = model.call(input);
            let loss = cross_entropy(pred, target);

            let loss_asum = loss.get_data().asum();

            // バックワードとパラメータ更新
            loss.backward();
            optimizer.update(&model);
            loss.clear_grad();

            train_loss += loss_asum;
            num_iter += 1;
        }
        train_loss /= num_iter as f32;

        // バリデーション
        no_train(); // PyTorch の with torch.no_grad()
        let val_loader = DataLoader::new(MnistDataset { data: val.clone() }, 1);
        let mut val_loss = 0.0;
        let mut num_val_iter = 0;

        for batch in val_loader {
            let input = batch[0].clone();
            let target = batch[1].clone();
            let pred = model.call(input);
            let loss = cross_entropy(pred, target);

            val_loss += loss.get_data().asum();
            num_val_iter += 1;
        }
        val_loss /= num_val_iter as f32;

        println!("Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}");
    }

    // テスト
    let mut test_loss = 0.0;
    let mut num_test_iter = 0;
    for batch in test_dataloader {
        let input = batch[0].clone();
        let target = batch[1].clone();
        let pred = model.call(input);
        let loss = cross_entropy(pred, target);

        test_loss += loss.get_data().asum();
        num_test_iter += 1;
    }

    println!("Test Loss: {}", test_loss / num_test_iter as f32);
}

GPU で走らせたい場合は、feature で nvidia を有効にして、Variableやモデルに対して.to::<Nvidia>() を呼び出すだけです。


他の Rust 製フレームワークとの比較

  • candle
    Rust で自前実装を頑張っているフレームワーク。Zenu と方向性は似ていますが、Metal や FlashAttention など先進的なバックエンドもサポートしていてすごい。
  • tch-rs / burn
    • tch-rs は PyTorch C++ バインディング。中身は PyTorch と同じなので実績は十分だけど、Rust の所有権や型システムを最大限活かすわけではない。
    • burn は独自 DSL を構築するアプローチが印象的。Zenu とは別のベクトルで面白い。

今後の展望

  • SIMD
    CPU での SSE / AVX などを利用した高速化
  • マルチ GPU / 分散学習
  • Transformer 系モデルの実装
  • PyTorch モデルの読み書き / ONNX 対応
  • ドキュメントの充実
    • チュートリアルや API リファレンスを拡充
  • サンプルコードの拡充
    • 画像分類、物体検出、セグメンテーションなど
    • 音声系の会社に転職したので音声系も頑張りたい
  • エラーハンドリングの改善
    • よりわかりやすいエラーメッセージを目指す

まとめ

  • Rust 製の深層学習フレームワーク Zenu を作っています。
  • CPU / GPU 両方に対応し、型安全とメモリ安全を重視した設計です。
  • MNIST や GAN などのサンプルコードがありますので、ぜひ試してみてください。
  • Issues、PR、マサカリ、大歓迎です!スターをいただけるとモチベが爆上がりします。

参考リンク

GitHubで編集を提案

Discussion