$\nabla$-Reasoner: LLM Reasoning via Test-Time Gradient Descent in Latent Space

The official implementation of the ICLR 2026 paper Nabla-Reasoner: LLM Reasoning via Test-Time Gradient Descent in Latent Space.

Peihao Wang*, Ruisi Cai*, Zhen Wang, Hongyuan Mei, Qiang Liu, Pan Li, Atlas Wang

International Conference on Learning Representations (ICLR), 2026

* denotes equal contribution.

Paper | Code

Get Started

Environment

We tested this release with the environment below (nearby versions should also work).

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip install packaging lightning==2.5.0 lightning[app] lightning[data] rich
pip install transformers==4.56.1 tokenizers==0.22.0 datasets==4.0.0 accelerate==1.10.1
pip install jsonargparse[signatures] sentencepiece wandb torchmetrics psutil
pip install tensorboard zstandard pandas pyarrow huggingface_hub
pip install flash-attn==2.8.3
pip install einops opt_einsum
pip install latex2sympy2 word2number pylatexenc
pip install vllm==0.10.2
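Since nearby versions should also work, a quick sanity check of what is actually installed can help. The sketch below compares a sample of the pins above against the local environment (the package list is illustrative, not exhaustive):

```python
from importlib.metadata import PackageNotFoundError, version

# A sample of the pins from the install commands above.
PINNED = {
    "transformers": "4.56.1",
    "tokenizers": "0.22.0",
    "datasets": "4.0.0",
    "accelerate": "1.10.1",
    "vllm": "0.10.2",
}

def report(pins):
    """Map each package name to (installed version or None, pinned version)."""
    out = {}
    for name, pinned in pins.items():
        try:
            out[name] = (version(name), pinned)
        except PackageNotFoundError:
            out[name] = (None, pinned)
    return out

if __name__ == "__main__":
    for name, (installed, pinned) in report(PINNED).items():
        marker = "OK" if installed == pinned else "CHECK"
        print(f"[{marker}] {name}: installed={installed}, pinned={pinned}")
```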

Data

Evaluation sets are adapted from Spurious_Rewards and hosted at peihaowang/math-reasoning-eval. Please download the datasets and place them under the data/ directory.
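One way to fetch the evaluation sets programmatically is with `huggingface_hub` (already installed above). The `repo_type` and local layout below are assumptions; adjust them if the hosted dataset is organized differently:

```python
from pathlib import Path

DATA_DIR = Path("data")

def download_eval_sets(repo_id="peihaowang/math-reasoning-eval", data_dir=DATA_DIR):
    """Download the evaluation sets into `data_dir` (repo_type is an assumption)."""
    from huggingface_hub import snapshot_download  # part of the environment above
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    local = snapshot_download(repo_id=repo_id, repo_type="dataset",
                              local_dir=str(data_dir))
    return Path(local)
```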

Usage

Launch vLLM

Start a vLLM server before running $\nabla$-Reasoner with backend="vllm".

python -m vllm.entrypoints.openai.api_server \
  --model <lm_model_path> \
  --dtype bfloat16 \
  --tensor-parallel-size 8 \
  --host 0.0.0.0 \
  --port 8000
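Before launching $\nabla$-Reasoner, it can be useful to wait until the server is actually serving requests. The sketch below polls the standard OpenAI-compatible `/v1/models` endpoint that vLLM exposes; the URL and timeouts are illustrative:

```python
import json
import time
import urllib.error
import urllib.request

def models_url(base_url):
    """Build the OpenAI-compatible model-listing endpoint URL."""
    return base_url.rstrip("/") + "/v1/models"

def wait_for_server(base_url="http://127.0.0.1:8000", timeout_s=300.0, poll_s=5.0):
    """Poll /v1/models until the server answers; return the served model IDs."""
    deadline = time.monotonic() + timeout_s
    url = models_url(base_url)
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                return [m["id"] for m in json.load(resp)["data"]]
        except (urllib.error.URLError, OSError):
            time.sleep(poll_s)
    raise TimeoutError(f"vLLM server at {base_url} did not come up in {timeout_s}s")
```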

Python API Example

The snippet below shows the minimal flow: initialize models, configure optimization hyperparameters, create NablaDecoding, then generate a response.

import torch
from decoding import NablaDecoding

device = "cuda:0"

# Initialize base and rewards models and their tokenizers
# lm_model, lm_tokenizer, rm_model, rm_tokenizer = ...

train_args = {
    "max_iters": 100,
    "warmup_iters_ratio": 0.0,
    "learning_rate": 0.01,
    "min_lr_ratio": 0.1,
    "weight_decay": 0.0,
    "reward_coeff": 1.0,
    "mixed_precision": torch.bfloat16,
    "grad_caching": True,
    "update_postfix": False,
    "embedder_type": "latents",
}

decoder = NablaDecoding(
    lm_model,
    lm_tokenizer,
    rm_model,
    rm_tokenizer,
    train_args,
    device=device,
    max_length=3072,
    verbose=2,
    rejection_sampling=True,
    max_n_generations=8,
    rollout_tau=0.7,
    rollout_top_k=20,
    rollout_top_p=0.8,
    resample_tau=0.5,
    resample_top_k=20,
    resample_top_p=0.8,
    backend="vllm",
    vllm_url="http://127.0.0.1:8000",
    vllm_model_name=lm_model_path,  # same model path used to launch the vLLM server
    confidence_selector_threshold=0.97,
    grad_selector_threshold=8,
)

prompt = "Solve: If 2x + 3 = 11, what is x?"
token_ids = decoder.generate(prompt, return_prompt=False, seed=42)
response = lm_tokenizer.decode(token_ids[0], skip_special_tokens=True)

If your vLLM version does not support token-id based I/O (for example vllm <= 0.7.2), set vllm_output_type="text". You can also run with backend="huggingface" (no vLLM server required).

Single-Prompt CLI

Use run.py for a quick test on a single prompt.

python run.py \
  --lm_model_name <lm_model_path> \
  --rm_model_name <rm_model_path> \
  --vllm_url http://127.0.0.1:8000 \
  --vllm_model_name <lm_model_path> \
  --prompt "<your prompt>" \
  --embedder_type latents \
  --max_iters 100 \
  --learning_rate 0.01 \
  --reward_coeff 1.0

Parallel Benchmark Run

Use eval/multi_run.py to process an entire benchmark in parallel across multiple workers/GPUs.

python eval/multi_run.py \
  --lm_model_name <lm_model_path> \
  --rm_model_name <rm_model_path> \
  --vllm_url http://127.0.0.1:8000 \
  --vllm_model_name <lm_model_path> \
  --prompts <dataset_name> \
  --num_procs 8 \
  --output_dir <output_dir> \
  --embedder_type latents \
  --n_generations 8 \
  --max_iters 100 \
  --learning_rate 0.01 \
  --reward_coeff 1.0

Evaluation

After generation, compute benchmark metrics from responses.json using:

python eval/eval_outputs.py \
  --json_path <output_dir>/responses.json \
  --output_file <output_dir>/eval.json
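To inspect the resulting metrics programmatically, a minimal loader like the one below works if eval.json is a flat {metric: value} JSON object (an assumption about the output format; adapt if the file is nested):

```python
import json
from pathlib import Path

def load_metrics(eval_path):
    """Load eval.json, assumed to be a flat {metric: value} JSON object."""
    return json.loads(Path(eval_path).read_text())

def print_metrics(eval_path):
    """Print metrics sorted by name for a quick summary."""
    for name, value in sorted(load_metrics(eval_path).items()):
        print(f"{name}: {value}")
```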

Citation

If you find this repository useful, please cite:

@inproceedings{wang2026nabla,
  title={$\nabla$-Reasoner: LLM Reasoning via Test-Time Gradient Descent in Latent Textual Space},
  author={Wang, Peihao and Cai, Ruisi and Wang, Zhen and Mei, Hongyuan and Liu, Qiang and Li, Pan and Wang, Atlas},
  booktitle={International Conference on Learning Representations},
  year={2026}
}
