rlhf bug fix
jstzwj committed Jul 4, 2024
1 parent 82a8ed6 commit ebb1b1e
Showing 15 changed files with 380 additions and 203 deletions.
15 changes: 8 additions & 7 deletions hparams/hparams_ppo_qwen1.5_4b.json
@@ -4,21 +4,22 @@
     "end_of_conversation": 151643,
     "data_path": [
         {
-            "path": "Vtuber-plan/sharegpt-cleaned",
-            "sample": 0.05
+            "path": "Dahoas/rm-static",
+            "sample": 1.0,
+            "preprocessor": "rlhf_rm_static"
         }
-
     ],
     "data_output_path": "./tmp/data_files/",
     "model_path": "Qwen/Qwen1.5-4B",
     "tokenizer_path": "Qwen/Qwen1.5-4B",
     "model_class": "AutoModelForSequenceClassification",
     "atten_class": "eager",
-    "reward_model_path": "Qwen/Qwen1.5-4B",
+    "reward_model_path": "llm_trainer_reward/lightning_logs/version_0/huggingface_format/checkpoint-step-1100",
     "reward_tokenizer_path": "Qwen/Qwen1.5-4B",
-    "per_device_train_batch_size": 2,
-    "per_device_eval_batch_size": 8,
-    "accumulate_grad_batches": 64,
+    "per_device_train_mini_batch_size": 2,
+    "per_device_train_batch_size": 4,
+    "per_device_eval_mini_batch_size": 8,
+    "accumulate_grad_batches": 1,
     "max_seq_len": 2048,
     "checkpoint_every_n_train_steps": 100,
     "log_every_n_steps": 1,
17 changes: 9 additions & 8 deletions hparams/hparams_ppo_qwen1.5_4b_chat.json
@@ -1,24 +1,25 @@
 {
     "train_stage": "ppo",
-    "conv_format": "openbuddy",
+    "conv_format": "qwen",
     "end_of_conversation": 151643,
     "data_path": [
         {
-            "path": "Vtuber-plan/sharegpt-cleaned",
-            "sample": 0.05
+            "path": "Dahoas/rm-static",
+            "sample": 1.0,
+            "preprocessor": "rlhf_rm_static"
         }
-
     ],
     "data_output_path": "./tmp/data_files/",
     "model_path": "Qwen/Qwen1.5-4B-chat",
     "tokenizer_path": "Qwen/Qwen1.5-4B-chat",
     "model_class": "AutoModelForSequenceClassification",
     "atten_class": "eager",
-    "reward_model_path": "Qwen/Qwen1.5-4B-chat",
+    "reward_model_path": "llm_trainer_reward/lightning_logs/version_0/huggingface_format/checkpoint-step-1100",
     "reward_tokenizer_path": "Qwen/Qwen1.5-4B-chat",
-    "per_device_train_batch_size": 2,
-    "per_device_eval_batch_size": 8,
-    "accumulate_grad_batches": 64,
+    "per_device_train_mini_batch_size": 2,
+    "per_device_train_batch_size": 32,
+    "per_device_eval_mini_batch_size": 8,
+    "accumulate_grad_batches": 16,
     "max_seq_len": 2048,
     "checkpoint_every_n_train_steps": 100,
     "log_every_n_steps": 1,
2 changes: 1 addition & 1 deletion hparams/hparams_reward_qwen1.5_4b.json
@@ -6,7 +6,7 @@
         {
             "path": "Dahoas/rm-static",
             "sample": 1.0,
-            "preprocessor": "rm_static"
+            "preprocessor": "reward_rm_static"
         }
 
     ],
2 changes: 1 addition & 1 deletion hparams/hparams_reward_seq_qwen1.5_4b.json
@@ -6,7 +6,7 @@
         {
             "path": "Dahoas/rm-static",
             "sample": 1.0,
-            "preprocessor": "rm_static"
+            "preprocessor": "reward_rm_static"
         }
 
     ],
39 changes: 39 additions & 0 deletions hparams/hparams_reward_value_qwen1.5_4b.json
@@ -0,0 +1,39 @@
{
    "train_stage": "reward_seq",
    "conv_format": "openbuddy",
    "end_of_conversation": 151643,
    "data_path": [
        {
            "path": "Dahoas/rm-static",
            "sample": 1.0,
            "preprocessor": "reward_rm_static"
        }

    ],
    "data_output_path": "./tmp/data_files/",
    "model_name_or_path": "Qwen/Qwen1.5-4B",
    "model_class": "AutoModelForCausalLMWithValueHead",
    "atten_class": "eager",
    "per_device_train_batch_size": 2,
    "per_device_eval_batch_size": 8,
    "accumulate_grad_batches": 64,
    "max_seq_len": 2048,
    "checkpoint_every_n_train_steps": 100,
    "log_every_n_steps": 1,
    "val_check_interval": 0.25,
    "limit_val_batches": 0.1,
    "learning_rate": 4e-5,
    "betas": [0.9, 0.95],
    "eps": 8e-6,
    "lr_decay": 0.999875,
    "lr_scheduler_type": "cosine",
    "num_warmup_steps": 100,
    "max_epochs": 300,
    "disable_dropout": true,
    "model_torch_dtype": "auto",
    "bf16": true,
    "gradient_checkpointing": true,
    "weight_decay": 0.0,
    "gradient_clip_algorithm": "norm",
    "gradient_clip_val": 1.0
}
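This new config trains the value model with "model_class": "AutoModelForCausalLMWithValueHead". The class name matches the one exported by the trl library; whether katheryne resolves it to trl's implementation or to its own is an assumption here. A minimal sketch of what such a model adds:

# Hedged sketch: a causal LM plus a scalar value head, as in trl.
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("Qwen/Qwen1.5-4B")
# forward() returns (lm_logits, loss, values); the value head projects each
# token's final hidden state to a scalar, usable as a PPO state-value estimate.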
34 changes: 23 additions & 11 deletions katheryne/data/loader/__init__.py
@@ -40,17 +40,21 @@ def from_data_path(cls, data_path: List[Union[str, "DatasetPath"]]) -> List["Dat
 
 import datasets
 
-def restructure_datasets(dataset: DatasetPath, field: Union[str, List[str]], field_map: Dict[str, str]={}, split:str="train", data_dir=None, data_files=None):
+def load_data_split(dataset: DatasetPath, split:str="train", data_dir=None, data_files=None):
     dataset_name = dataset.path
-    raw_datasets = datasets.load_dataset(dataset_name, split=split, data_dir=data_dir, data_files=data_files)
-    train_dataset = raw_datasets
+    raw_dataset_split = datasets.load_dataset(dataset_name, split=split, data_dir=data_dir, data_files=data_files)
+    return raw_dataset_split
+
+def restructure_datasets(raw_dataset, field: Union[str, List[str]], field_map: Dict[str, str]={}):
+    train_dataset = raw_dataset
     cols = train_dataset.column_names
 
     if isinstance(field, str):
         cols.remove(field)
     else:
         for each_field in field:
-            cols.remove(each_field)
+            if each_field in cols:
+                cols.remove(each_field)
     train_dataset = train_dataset.remove_columns(cols)
 
     for old_name, new_name in field_map.items():
@@ -62,19 +66,20 @@ def restructure_datasets(dataset: DatasetPath, field: Union[str, List[str]], fie
 def create_dataset(dataset: DatasetPath, columns: List[str], preprocessor: Optional[Union[str, List[str]]]=None, seed=43) -> Tuple[datasets.Dataset, datasets.Dataset]:
     dataset_name = dataset.path
     raw_datasets = datasets.load_dataset(dataset_name)
+
     if "train" in raw_datasets:
-        raw_train_dataset = restructure_datasets(dataset, field=columns, split="train")
+        raw_train_dataset = load_data_split(dataset, split="train")
     else:
         raw_train_dataset = None
 
     if "validation" in raw_datasets:
-        raw_validation_dataset = restructure_datasets(dataset, field=columns, split="validation")
+        raw_validation_dataset = load_data_split(dataset, split="validation")
     elif "valid" in raw_datasets:
-        raw_validation_dataset = restructure_datasets(dataset, field=columns, split="valid")
+        raw_validation_dataset = load_data_split(dataset, split="valid")
     elif "eval" in raw_datasets:
-        raw_validation_dataset = restructure_datasets(dataset, field=columns, split="eval")
+        raw_validation_dataset = load_data_split(dataset, split="eval")
     elif "evaluation" in raw_datasets:
-        raw_validation_dataset = restructure_datasets(dataset, field=columns, split="evaluation")
+        raw_validation_dataset = load_data_split(dataset, split="evaluation")
     else:
         raw_validation_dataset = None
 
@@ -85,7 +90,7 @@ def create_dataset(dataset: DatasetPath, columns: List[str], preprocessor: Optio
     else:
         train_dataset = raw_train_dataset
         eval_dataset = raw_validation_dataset
 
     if preprocessor is not None:
         preprocessor_fns = []
         if isinstance(preprocessor, str):
@@ -97,9 +102,16 @@ def create_dataset(dataset: DatasetPath, columns: List[str], preprocessor: Optio
 
         for fn_name in preprocessor_fns:
             fn = get_map_fn(fn_name)
+            if fn is None:
+                raise Exception(f"Dataset preprocessor `{fn_name}` is not found.")
             train_dataset = train_dataset.map(fn)
             eval_dataset = eval_dataset.map(fn)
 
+
+    # restructure dataset
+    train_dataset = restructure_datasets(train_dataset, columns)
+    eval_dataset = restructure_datasets(eval_dataset, columns)
+
     # Shuffle
     if dataset.shuffle:
         train_dataset = train_dataset.shuffle(seed=seed)
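Net effect of the loader refactor: create_dataset now loads the raw split first, applies the named preprocessors, and only then restructures to the requested columns, so columns created by a preprocessor (such as "messages") are no longer stripped before they exist. A condensed sketch of the new flow (argument values are illustrative):

# Condensed view of the refactored create_dataset pipeline.
raw = load_data_split(dataset, split="train")        # 1) load the raw split
fn = get_map_fn("rlhf_rm_static")                    # 2) resolve the preprocessor
if fn is None:
    raise Exception("Dataset preprocessor `rlhf_rm_static` is not found.")
raw = raw.map(fn)                                    #    adds a "messages" column
train_dataset = restructure_datasets(raw, ["messages"])  # 3) keep only requested columns
train_dataset = train_dataset.shuffle(seed=43)       # 4) optional shuffle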
11 changes: 6 additions & 5 deletions katheryne/data/preprocessors/hub.py
@@ -9,13 +9,14 @@
     "get_map_fn",
 ]
 
-from .pretrain.merge_fields import fn_merge_fields
-from .reward.rm_static import fn_rm_static
-
+from .pretrain.merge_fields import fn_merge_fields as pretrain_fn_merge_fields
+from .reward.rm_static import fn_rm_static as reward_rm_static
+from .rlhf.rm_static import fn_rm_static as rlhf_fn_rm_static
 
 DATASET_PREPROCESSOR_FUNCTIONS = {
-    "rm_static": fn_rm_static,
-    "merge_fields": fn_merge_fields,
+    "reward_rm_static": reward_rm_static,
+    "rlhf_rm_static": rlhf_fn_rm_static,
+    "pretrain_merge_fields": pretrain_fn_merge_fields,
 }
 
 
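The imports are now aliased per training stage so the reward and RLHF variants of fn_rm_static can coexist in one registry; the old "rm_static" key is gone, which is why the hparams files above move to "reward_rm_static" and "rlhf_rm_static". A hedged usage sketch (the load_dataset call is illustrative):

from katheryne.data.preprocessors.hub import get_map_fn
import datasets

fn = get_map_fn("rlhf_rm_static")   # returns None for unknown names
ds = datasets.load_dataset("Dahoas/rm-static", split="train")
ds = ds.map(fn)                     # each row gains a "messages" list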
Empty file.
43 changes: 43 additions & 0 deletions katheryne/data/preprocessors/rlhf/rm_static.py
@@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from typing import Any, Dict

def fn_rm_static(sample: Dict[str, Any]) -> Dict[str, Any]:
    prompt: str = sample["prompt"]
    lines = prompt.splitlines()

    messages = []
    for line in lines:
        if len(line.strip()) == 0:
            continue
        line = line.strip()

        if line.startswith("Human:"):
            if len(messages) != 0 and messages[-1]["role"] == "user":
                messages[-1]["content"] += "\n" + line
            else:
                messages.append({
                    "role": "user",
                    "content": line.lstrip(),
                })
        elif line.startswith("Assistant:"):
            if len(messages) != 0 and messages[-1]["role"] == "assistant":
                messages[-1]["content"] += "\n" + line
            else:
                messages.append({
                    "role": "assistant",
                    "content": line.lstrip(),
                })
        else:
            if len(messages) != 0:
                messages[-1]["content"] += "\n" + line
            else:
                raise Exception("fn_rm_static: invalid data, the prompt does not start with a Human: or Assistant: turn.")

    sample["messages"] = messages
    return sample
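A quick illustration of the parser on a Dahoas/rm-static style prompt (the text is invented; note that the "Human:"/"Assistant:" prefixes stay inside content, since only surrounding whitespace is stripped):

sample = {"prompt": "\n\nHuman: What is RLHF?\n\nAssistant: Reinforcement learning from human feedback."}
out = fn_rm_static(sample)
# out["messages"] == [
#     {"role": "user", "content": "Human: What is RLHF?"},
#     {"role": "assistant", "content": "Assistant: Reinforcement learning from human feedback."},
# ]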
14 changes: 11 additions & 3 deletions katheryne/datasets/rlhf_dataset.py
@@ -41,9 +41,17 @@ def __getitem__(self, idx):
         sample = self.dataset[idx]
 
         messages = sample["messages"]
+        if len(messages) > 0:
+            if messages[-1]["role"].lower() == "assistant":
+                messages[-1]["content"] = None
+            elif messages[-1]["role"].lower() == "user":
+                messages.append({
+                    "role": "assistant",
+                    "content": None,
+                })
         prompt, indices = self.get_prompt(messages)
-        if isinstance(self.end_of_conversation, str):
-            prompt += self.end_of_conversation
+        # if isinstance(self.end_of_conversation, str):
+        #     prompt += self.end_of_conversation
 
         encoded_text = self.tokenize(prompt, add_special_tokens=True)
 
@@ -56,7 +64,7 @@ def __getitem__(self, idx):
             input_ids[-1] = self.tokenizer.eos_token_id
         attention_mask = attention_mask[:self.max_seq_len]
 
-        input_ids, attention_mask = self.add_end_of_conv(input_ids, attention_mask, self.end_of_conversation)
+        # input_ids, attention_mask = self.add_end_of_conv(input_ids, attention_mask, self.end_of_conversation)
 
         labels = input_ids.clone()
         labels = self.mask_label(prompt, labels, indices)
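Read together, the two hunks above appear to be the core of the fix: each sampled item now ends in an open assistant turn (a trailing gold reply is blanked, and a trailing user turn gets an empty assistant slot), and the end-of-conversation token is no longer appended, so get_prompt() yields a generation prompt for the PPO rollout instead of a completed dialogue. Assumed effect on the rendered prompt, using the "qwen" format from the hparams above:

# before: "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n<|endoftext|>"
# after:  "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"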
Expand Down