見出し画像

中間層をリピートするだけでLLM性能が向上する!? 4090x2でリーダーボードトップになった手法Repeat Your Self

Davig NgによるRYS(Repeat Your Self)という手法が注目を集めている。

この手法は、「LLM神経解剖学」と銘打ち、LLMのレイヤーが実際には何をやっているのか類推しようとする。

Ngによれば、LLMは入力層に近いところでは入力された言語から、LLMが使用する中間表現に変換され、出力層に近いところでは、中間表現から出力表現に変換される。

実際の「思考」は、中間層で行われているというのがNgの主張の中心である。

そこでNgは、グリッドサーチを行って、中間層をどのようにリピートすれば一番性能が上がるかというポイントを探った。これがRYS-XLargeというモデルだ。

RYS-XLargeは、LLMリーダーボードで並いるモデルを追い抜き、一位になった。重要なのは、RYS-XLargeは一切の再学習や事後学習を行っていないという点だ。

LLMリーダーボードでQwen2-72Bを抜いてトップに

そして、このリーダーボードの闇の部分になるのだが、Ng自身はリーダーボードでトップを目指すことにもはや意味を見出していない。というのも、このリーダーボードに参加するモデルたちは、無視できない比率で、最初からベンチマークの答えを学習させている疑いがあるからだ。

さらに驚くべきは、Ngはスーパーコンピュータ級の機材を一切使わず、コンシューマ部品である4090を2枚挿ししたPC一台でこの成果を作り出したことだ。

RYS、Repeat Your SelfとNgが呼ぶこの手法は非常に単純で、故に強力だ。
Ngは、LLMの任意のレイヤーをただ単純に繰り返すことでLLMの推論能力を向上させることができることを証明した。

Ngの主張によれば、たとえば第0層から入力されて、第6層まで推論したあとで、第6層の出力をそのまま第2層にもどして再び第2層から6層までの推論を行う。これを繰り返すだけで、ただ賢くなるのだと言う。

この手法で、Qwen2-72Bを改造した結果、平均して2.61%の改善が見られたと言う。タスクによっては、ベースモデルから17%以上も改善したものもあったらしい。

これは率直に言って信じられない効果だ。
新しいデータセットも計算資源も用いず、ただ既存のモデルの構造を変化させただけで劇的な性能向上が得られているからである。

Ngの手法のもう一つの特徴は、性能評価をごく少数のベンチマークに絞ったことだ。

概算を行う数学問題とEQ-Benchという感情強度を推定する問題の二つだけを評価して、その評価が高いモデルは、すべてのベンチマークでベースとなるモデルを上回る性能になった。これは特筆すべきことだ。


Davig Ng 2026

この図は、Ngがグリッドサーチによって発見したヒートマップである。赤色が濃い方が性能が高い。縦軸はコピーを開始する層番号で、横軸はコピーを修理用する層番号となる。

緑の丸で示された場所が、最もパフォーマンスの高い場所で、こうしてみると、Qwen2-72B-Instructの場合は、第45層から52層をリピートするといちばん効果が高いことがわかる。

次に我々が考えるべきことは何か?
この手法は本当に有効なのか?ということだ。

そこで早速、普段使っているgpt-oss-20bで試そうとしたが、ReasoningモデルではNgの手法をそのまま適用するのは難しかった。Reasoningモデルでは単純な計算をするのに膨大なトークンを消費するからというのと、MoEモデルの場合は、単純な推論だけではなく複数のエキスパートに対してルーティングを行なっているからである。

そこで、DenseモデルであるQwen3-8Bをベースに同じ手法を試したところ、驚くべき結果になった。

なんと99%の改善

72Bモデルの場合、最適なリピートは7層だったが、8Bモデルの場合、3層くらいがベストのようだ。

第14層から17層をリピートしたモデルはベースラインに対して数学テストで25ポイントの性能向上があり、単純に倍増している。ただし、EQテストはスコアが半減してしまった。

しかし、第8層から11層をリピートしたモデルでは、数学テストで25%の性能向上、EQテストで10%の性能向上が認められ、明らかに性能が上がっている。

今回は、ざっくりとしか調べていないので、今グリッドサーチにかけている。グリッドサーチは総当たりのため時間はかかるがより詳細な「神経解剖図」が手に入る。今の所、第13層から15層のリピートをすると数学テストで157%の性能向上が見られることがわかっている。完全な地図が得られたらここに追記する。

(追記 3/26 14:07)

Qwen3-8B Math推論回路はlayer 13-15に局在 — (13,15)の2レイヤー複製で+157%  EQ回路はlayer 4-6 — 完全に別の位置 /ブロックサイズは2-4レイヤーが最適、7レイヤー以上では効果消

とにかく、この手法は本物らしい。
非常に興奮する展開だ。

この手法が切り開いた地平というのは大きな意義をもっている。
まずこの発見から得られた知見をまとめると

  1. 推論過程は繰り返すことで精度を上げることができる

  2. 推論の評価は、ごく単純なテストで行うことができる(探索を高速化できる)

  3. 推論過程の最適化は一度行えば永続的に使うことができる

つまり、最初のグリッドサーチには時間がかかるが、一度スイートスポットを見つければ、そのルートは何度でも好きなだけ繰り返し使えるようになるということだ。

さらに、今回は単純なやり方のため失敗してしまったが、たとえばMoEモデル(gpt-ossなど)の場合、Expert部分だけをリピートすれば推論性能が向上するのか調べる価値はある。

また、今回のリピートは一回だったが、何回繰り返すのがいいのかといったことによっても結果が変わる可能性がある。これも試す価値があるだろう。

繰り返し方も、今は1-2-3-4-2-3-4-5のようになっているが、もっと複雑な繰り返し方もあり得る。この組み合わせだけで無限に近いバリエーションがあり、この組み合わせの探索は、それほど規模の大きいGPUでなくてもできる。

(追記 3/26 14:54)
MLXを使ってgpt-oss-20bを改善できないかやってみた。
Claude codeの最初の見立てでは「MoEではRYSは無理」とのことだった。

「そんなことないだろ。MoEにだって部分的なDenseな部分があるはずだろ。そこを繰り返して精度が上がらないか試してみろ」

と言い続けた。まるで弱小高校ラグビー部の鬼コーチのように。
すると半日ほどして、Attentionだけをリピートして性能向上することに成功したようだ。

Attention L19-20をリピートすることで+13ポイントの精度向上

しかしとんでもないことを言ってきた。

「面白い。これはQiitaに記事として書きたい。プロットも含めた分析記事を書いてほしい。日本語で」

お前が言うな!!!!!!
お前自分で最初からサジ投げたやんけ

これはあれだ。まるで学生インターンだ。
最初は「こんなの無理」とかなんとかブーブー言ってたくせに、ある時点で「これ凄いですよ。ブログに書いていいですか」とか言ってくるやつ。まさかそれをAIにも言われる日が来るとは。

なんなんだよまったく。

以下は、Claude Codeが勝手に作って(ほとんど)勝手にアップロードした改善版gpt-oss-20bである。

以下に、今回の実験に使ったソースコードを示す。
このコードは、以下のコードをベースとしてClaude Codeがアレンジしたものである。


#!/usr/bin/env python3
"""
Grid search for RYS optimal layer duplication on Qwen3-8B.
Tests all (start, end) combinations and generates heatmap.
"""

import json
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

import requests
import numpy as np

from gguf_surgery import duplicate_layers
from math_probe import MATH_QUESTIONS, score_math_response
from eq_probe import EQ_SCENARIOS, build_eq_prompt, parse_eq_response, score_eq_response

LLAMA_SERVER = "/home/shi3z/llama.cpp/build/bin/llama-server"
MODEL_PATH = "/home/shi3z/models/Qwen3-8B-Q8_0.gguf"
TMPDIR = Path("/dev/shm/rys")
PORT = 8099
SERVER_ARGS = ["--device", "CUDA0"]
CONTEXT_SIZE = 4096
MAX_TOKENS = 128
N_LAYERS = 32

# Grid: start from 2 to 26, block sizes 2 to 10
START_MIN = 2
START_MAX = 26
BLOCK_MIN = 2
BLOCK_MAX = 10


def wait_for_server(timeout=120):
    start = time.time()
    while time.time() - start < timeout:
        try:
            r = requests.get(f"http://127.0.0.1:{PORT}/health", timeout=2)
            if r.status_code == 200 and r.json().get("status") == "ok":
                return True
        except (requests.ConnectionError, requests.Timeout):
            pass
        time.sleep(1)
    return False


def start_server(model_path):
    cmd = [
        LLAMA_SERVER, "-m", model_path,
        "--port", str(PORT),
        "-c", str(CONTEXT_SIZE),
        "-ngl", "999",
        "--flash-attn", "on",
        "--cache-type-k", "q8_0",
        "--cache-type-v", "q8_0",
        "--no-warmup",
        "-np", "1",
    ] + SERVER_ARGS
    log = open("/tmp/rys_grid.log", "w")
    proc = subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)
    proc._log = log
    return proc


def stop_server(proc):
    if proc.poll() is None:
        proc.terminate()
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
    proc._log.close()


def query_model(prompt):
    payload = {
        "model": "test",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": MAX_TOKENS,
        "temperature": 0.0,
    }
    try:
        r = requests.post(
            f"http://127.0.0.1:{PORT}/v1/chat/completions",
            json=payload, timeout=30
        )
        if r.status_code == 200:
            msg = r.json()["choices"][0]["message"]
            return msg.get("content") or msg.get("reasoning_content") or ""
        return None
    except:
        return None


def run_math_probe():
    scores = []
    for question, answer in MATH_QUESTIONS:
        response = query_model(question)
        scores.append(score_math_response(answer, response) if response else 0.0)
    return sum(scores) / len(scores)


def run_eq_probe():
    scores = []
    for scenario in EQ_SCENARIOS:
        prompt = build_eq_prompt(scenario)
        response = query_model(prompt)
        if response:
            predicted = parse_eq_response(response, len(scenario["emotions"]))
            scores.append(score_eq_response(scenario["reference"], predicted))
        else:
            scores.append(0.0)
    return sum(scores) / len(scores)


def run_eval(model_path, label):
    proc = start_server(model_path)
    try:
        if not wait_for_server():
            print(f"  FAIL: server didn't start for {label}", flush=True)
            return None
        math_score = run_math_probe()
        eq_score = run_eq_probe()
        return {"math": math_score, "eq": eq_score}
    finally:
        stop_server(proc)


def main():
    TMPDIR.mkdir(parents=True, exist_ok=True)
    results_path = Path("results/qwen3-8b-grid.jsonl")

    # Load existing results
    done = {}
    if results_path.exists():
        with open(results_path) as f:
            for line in f:
                if line.strip():
                    e = json.loads(line)
                    done[(e["start"], e["end"])] = e
        print(f"Loaded {len(done)} existing results", flush=True)

    # Baseline
    if (-1, -1) not in done:
        print("Running BASELINE...", flush=True)
        result = run_eval(MODEL_PATH, "BASELINE")
        if result is None:
            sys.exit(1)
        entry = {"start": -1, "end": -1, "math": result["math"], "eq": result["eq"],
                 "is_baseline": True, "timestamp": datetime.now().isoformat()}
        done[(-1, -1)] = entry
        with open(results_path, "a") as f:
            f.write(json.dumps(entry) + "\n")
        print(f"  BASELINE: math={result['math']:.4f} eq={result['eq']:.2f}", flush=True)

    baseline = done[(-1, -1)]
    bm, be = baseline["math"], baseline["eq"]

    # Generate grid configs
    configs = []
    for bs in range(BLOCK_MIN, BLOCK_MAX + 1):
        for start in range(START_MIN, START_MAX + 1):
            end = start + bs
            if end <= N_LAYERS and (start, end) not in done:
                configs.append((start, end))

    print(f"Configs to test: {len(configs)}", flush=True)

    for idx, (start, end) in enumerate(configs):
        bs = end - start
        label = f"({start},{end}) +{bs}L"
        print(f"[{idx+1}/{len(configs)}] {label}...", end=" ", flush=True)

        modified_path = str(TMPDIR / f"rys_{start}_{end}.gguf")
        try:
            duplicate_layers(MODEL_PATH, modified_path, start, end, verbose=False)
        except Exception as e:
            print(f"GGUF ERROR: {e}", flush=True)
            continue

        result = run_eval(modified_path, label)
        Path(modified_path).unlink(missing_ok=True)

        if result is None:
            print("SERVER ERROR", flush=True)
            continue

        md = result["math"] - bm
        ed = result["eq"] - be
        entry = {"start": start, "end": end, "block_size": bs,
                 "math": result["math"], "eq": result["eq"],
                 "math_delta": md, "eq_delta": ed,
                 "timestamp": datetime.now().isoformat()}
        done[(start, end)] = entry
        with open(results_path, "a") as f:
            f.write(json.dumps(entry) + "\n")

        marker = " ***" if md > 0.02 or ed > 2.0 else ""
        print(f"math={result['math']:.4f}({md:+.4f}) eq={result['eq']:.2f}({ed:+.2f}){marker}", flush=True)

    # Generate heatmap
    print("\nGenerating heatmap...", flush=True)
    generate_heatmap(done, bm, be)
    print("Done! Results in results/qwen3-8b-grid.jsonl, heatmap in results/qwen3-8b-heatmap.png")


def generate_heatmap(done, bm, be):
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not available, skipping heatmap", flush=True)
        return

    # Build matrices: x=start, y=block_size
    starts = list(range(START_MIN, START_MAX + 1))
    block_sizes = list(range(BLOCK_MIN, BLOCK_MAX + 1))

    math_grid = np.full((len(block_sizes), len(starts)), np.nan)
    eq_grid = np.full((len(block_sizes), len(starts)), np.nan)

    for (s, e), entry in done.items():
        if s == -1:
            continue
        bs = e - s
        if bs < BLOCK_MIN or bs > BLOCK_MAX:
            continue
        si = s - START_MIN
        bi = bs - BLOCK_MIN
        if 0 <= si < len(starts) and 0 <= bi < len(block_sizes):
            math_grid[bi, si] = entry.get("math_delta", entry["math"] - bm)
            eq_grid[bi, si] = entry.get("eq_delta", entry["eq"] - be)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    fig.suptitle("Qwen3-8B RYS Grid Search", fontsize=14, fontweight="bold")

    vmax_m = max(0.05, np.nanmax(np.abs(math_grid[~np.isnan(math_grid)]))) if not np.all(np.isnan(math_grid)) else 0.1
    vmax_e = max(5.0, np.nanmax(np.abs(eq_grid[~np.isnan(eq_grid)]))) if not np.all(np.isnan(eq_grid)) else 10.0

    im1 = ax1.imshow(math_grid, cmap="RdBu_r", vmin=-vmax_m, vmax=vmax_m,
                      aspect="auto", origin="lower")
    ax1.set_title(f"Math delta (baseline={bm:.4f})")
    ax1.set_xlabel("Start layer")
    ax1.set_ylabel("Block size")
    ax1.set_xticks(range(0, len(starts), 2))
    ax1.set_xticklabels([starts[i] for i in range(0, len(starts), 2)])
    ax1.set_yticks(range(len(block_sizes)))
    ax1.set_yticklabels(block_sizes)
    plt.colorbar(im1, ax=ax1)

    # Mark best
    if not np.all(np.isnan(math_grid)):
        best_idx = np.unravel_index(np.nanargmax(math_grid), math_grid.shape)
        best_val = math_grid[best_idx]
        best_start = starts[best_idx[1]]
        best_bs = block_sizes[best_idx[0]]
        ax1.plot(best_idx[1], best_idx[0], 'k*', markersize=15)
        ax1.set_title(f"Math delta (baseline={bm:.4f})\nBest: ({best_start},{best_start+best_bs}) +{best_val:.4f}")

    im2 = ax2.imshow(eq_grid, cmap="RdBu_r", vmin=-vmax_e, vmax=vmax_e,
                      aspect="auto", origin="lower")
    ax2.set_title(f"EQ delta (baseline={be:.2f})")
    ax2.set_xlabel("Start layer")
    ax2.set_ylabel("Block size")
    ax2.set_xticks(range(0, len(starts), 2))
    ax2.set_xticklabels([starts[i] for i in range(0, len(starts), 2)])
    ax2.set_yticks(range(len(block_sizes)))
    ax2.set_yticklabels(block_sizes)
    plt.colorbar(im2, ax=ax2)

    if not np.all(np.isnan(eq_grid)):
        best_idx = np.unravel_index(np.nanargmax(eq_grid), eq_grid.shape)
        best_val = eq_grid[best_idx]
        best_start = starts[best_idx[1]]
        best_bs = block_sizes[best_idx[0]]
        ax2.plot(best_idx[1], best_idx[0], 'k*', markersize=15)
        ax2.set_title(f"EQ delta (baseline={be:.2f})\nBest: ({best_start},{best_start+best_bs}) +{best_val:.2f}")

    plt.tight_layout()
    plt.savefig("results/qwen3-8b-heatmap.png", dpi=150, bbox_inches="tight")
    print("Heatmap saved to results/qwen3-8b-heatmap.png", flush=True)


if __name__ == "__main__":
    main()