Skip to content

Commit

Permalink
update reward model impl
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jun 29, 2024
1 parent 13a3d90 commit 7e0fecb
Show file tree
Hide file tree
Showing 12 changed files with 517 additions and 26 deletions.
6 changes: 3 additions & 3 deletions hparams/hparams_reward_qwen1.5_4b.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
"data_path": [
{
"path": "Dahoas/rm-static",
- "sample": 1,
+ "sample": 1.0,
"preprocessor": "rm_static"
}

],
"data_output_path": "./tmp/data_files/",
- "model_name_or_path": "/data/wangjun/models/Qwen1.5-4B",
- "model_class": "AutoModel",
+ "model_name_or_path": "Qwen/Qwen1.5-4B",
+ "model_class": "AutoModelForTokenClassification",
"atten_class": "eager",
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 8,
Expand Down
39 changes: 39 additions & 0 deletions hparams/hparams_reward_seq_qwen1.5_4b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"train_stage": "reward_seq",
"conv_format": "openbuddy",
"end_of_conversation": 151643,
"data_path": [
{
"path": "Dahoas/rm-static",
"sample": 1.0,
"preprocessor": "rm_static"
}

],
"data_output_path": "./tmp/data_files/",
"model_name_or_path": "Qwen/Qwen1.5-4B",
"model_class": "AutoModelForSequenceClassification",
"atten_class": "eager",
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 8,
"accumulate_grad_batches": 64,
"max_seq_len": 2048,
"checkpoint_every_n_train_steps": 100,
"log_every_n_steps": 1,
"val_check_interval": 0.25,
"limit_val_batches": 0.1,
"learning_rate": 4e-5,
"betas": [0.9, 0.95],
"eps": 8e-6,
"lr_decay": 0.999875,
"lr_scheduler_type": "cosine",
"num_warmup_steps": 100,
"max_epochs": 300,
"disable_dropout": true,
"model_torch_dtype": "auto",
"bf16": true,
"gradient_checkpointing": true,
"weight_decay": 0.0,
"gradient_clip_algorithm": "norm",
"gradient_clip_val": 1.0
}
2 changes: 2 additions & 0 deletions katheryne/data/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def create_dataset(dataset: DatasetPath, columns: List[str], preprocessor: Optio
if dataset.sample != 1.0:
sample_size = int(dataset.sample * len(train_dataset))
train_dataset = train_dataset.select(list(range(sample_size)))
+ else:
+ sample_size = len(train_dataset)
else:
raise TypeError("Invalid sample number of dataset path object, need int or float.")

Expand Down
Loading

0 comments on commit 7e0fecb

Please sign in to comment.