Commit d87d9bc: reward dataset

jstzwj committed Jun 24, 2024 (parent commit: 7a1c1b9)

Showing 14 changed files with 537 additions and 356 deletions.
1 change: 1 addition & 0 deletions hparams/hparams_chat_ningyu_13b.json
@@ -1,4 +1,5 @@
 {
+    "train_stage": "chat",
     "conv_format": "openbuddy",
     "data_path": [
         {
8 changes: 6 additions & 2 deletions hparams/hparams_pretrain_llama2_7b_ddp_lora.json
@@ -1,4 +1,5 @@
 {
+    "train_stage": "pretrain",
     "data_path": [
         "bigscience-data/roots_zh-cn_wikipedia"
     ],
@@ -30,6 +31,9 @@
     },
     "lora": {
         "r": 128,
-        "target_modules": ["q_proj", "v_proj"]
-    }
+        "lora_alpha": 32,
+        "lora_dropout": 0.2,
+        "bias": "all",
+        "target_modules": "all-linear"
+    },
 }
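The expanded lora block mirrors the fields of a PEFT LoraConfig. A minimal sketch of how these hyperparameters could be applied with the peft library (an assumption about the downstream consumer; this diff does not show how katheryne actually builds the adapter):

# Sketch only: the "lora" hparams block expressed as a peft LoraConfig.
# Assumption: katheryne forwards these keys to peft unchanged (not shown in this diff).
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

lora_cfg = LoraConfig(
    r=128,                       # rank of the LoRA update matrices
    lora_alpha=32,               # scaling factor applied to the LoRA update
    lora_dropout=0.2,            # dropout on the LoRA branch
    bias="all",                  # also train every bias term
    target_modules="all-linear", # wrap every linear layer instead of just q_proj/v_proj
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()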
31 changes: 22 additions & 9 deletions hparams/hparams_pretrain_llama2_7b_deepspeed_lora.json
@@ -1,13 +1,14 @@
 {
+    "train_stage": "pretrain",
     "data_path": [
         "bigscience-data/roots_zh-cn_wikipedia"
     ],
     "data_output_path": "./tmp/data_files/",
     "model_name_or_path": "meta-llama/Llama-2-7b-hf",
-    "per_device_train_batch_size": 2,
+    "per_device_train_batch_size": 4,
     "per_device_eval_batch_size": 4,
-    "accumulate_grad_batches": 32,
-    "max_seq_len": 1024,
+    "accumulate_grad_batches": 128,
+    "max_seq_len": 2048,
     "checkpoint_every_n_train_steps": 1000,
     "log_every_n_steps": 1,
     "val_check_interval": 0.25,
@@ -24,13 +25,25 @@
     "bf16": true,
     "gradient_checkpointing": true,
     "weight_decay": 0.0,
-    "strategy": "deepspeed",
-    "strategy_params": {
-        "offload": false,
-        "zero_stage": 2
-    },
     "lora": {
         "r": 128,
-        "target_modules": ["q_proj", "v_proj"]
+        "lora_alpha": 32,
+        "lora_dropout": 0.2,
+        "bias": "all",
+        "target_modules": "all-linear"
     },
+    "strategy": "deepspeed",
+    "strategy_params": {
+        "zero_stage": 3,
+        "remote_device": null,
+        "offload_optimizer": false,
+        "offload_optimizer_device": "cpu",
+        "offload_parameters": false,
+        "offload_params_device": "cpu",
+        "cpu_checkpointing": true,
+        "nvme_path": "./nvme_offload",
+        "params_buffer_count": 5,
+        "params_buffer_size": 1000000000,
+        "contiguous_memory_optimization": false
+    }
 }
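The new strategy_params keys line up with the keyword arguments of PyTorch Lightning's DeepSpeedStrategy, with zero_stage presumably mapping to the stage argument. A rough sketch under that assumption (the trainer wiring itself is not part of this diff, and the Lightning 2.x import paths are assumed):

# Sketch only: strategy_params above expressed as a Lightning DeepSpeedStrategy.
# Assumption: "zero_stage" maps to "stage"; the remaining keys are passed through verbatim.
import lightning.pytorch as pl
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,
    remote_device=None,
    offload_optimizer=False,
    offload_optimizer_device="cpu",
    offload_parameters=False,
    offload_params_device="cpu",
    cpu_checkpointing=True,
    nvme_path="./nvme_offload",
    params_buffer_count=5,
    params_buffer_size=1_000_000_000,
    contiguous_memory_optimization=False,
)

trainer = pl.Trainer(strategy=strategy, precision="bf16-mixed", accumulate_grad_batches=128)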
3 changes: 2 additions & 1 deletion hparams/hparams_pretrain_llama2_7b_fsdp_lora.json
@@ -1,6 +1,7 @@
 {
+    "train_stage": "pretrain",
     "data_path": [
-        "dataset/pretrain_mlh"
+        "bigscience-data/roots_zh-cn_wikipedia"
     ],
     "data_output_path": "./tmp/data_files/",
     "model_name_or_path": "meta-llama/Llama-2-7b-hf",
37 changes: 37 additions & 0 deletions hparams/hparams_reward_qwen1.5_4b.json
@@ -0,0 +1,37 @@
+{
+    "train_stage": "reward",
+    "conv_format": "openbuddy",
+    "end_of_conversation": 151643,
+    "data_path": [
+        {
+            "path": "Dahoas/rm-static",
+            "sample": 1
+        }
+
+    ],
+    "data_output_path": "./tmp/data_files/",
+    "model_name_or_path": "/data/wangjun/models/Qwen1.5-4B",
+    "atten_class": "eager",
+    "per_device_train_batch_size": 2,
+    "per_device_eval_batch_size": 8,
+    "accumulate_grad_batches": 64,
+    "max_seq_len": 2048,
+    "checkpoint_every_n_train_steps": 100,
+    "log_every_n_steps": 1,
+    "val_check_interval": 0.25,
+    "limit_val_batches": 0.1,
+    "learning_rate": 4e-5,
+    "betas": [0.9, 0.95],
+    "eps": 8e-6,
+    "lr_decay": 0.999875,
+    "lr_scheduler_type": "cosine",
+    "num_warmup_steps": 100,
+    "max_epochs": 300,
+    "disable_dropout": true,
+    "model_torch_dtype": "auto",
+    "bf16": true,
+    "gradient_checkpointing": true,
+    "weight_decay": 0.0,
+    "gradient_clip_algorithm": "norm",
+    "gradient_clip_val": 1.0
+}
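This new file points the reward stage at Dahoas/rm-static. A quick sketch for inspecting that dataset's columns before training (not part of the commit; the expected column names are an assumption to verify):

# Sketch: inspect the Dahoas/rm-static columns before pointing the reward stage at it.
from datasets import load_dataset

ds = load_dataset("Dahoas/rm-static", split="train")
print(ds.column_names)              # expected to include prompt, response, chosen, rejected
print(ds[0]["chosen"][:200])
print(ds[0]["rejected"][:200])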
13 changes: 7 additions & 6 deletions katheryne/data/loader/reward.py
@@ -12,7 +12,7 @@

 import datasets
 from katheryne.data.loader import DatasetPath
-from katheryne.datasets.instruction_dataset import InstructionDataset
+from katheryne.datasets.reward_dataset import RewardDataset

 from katheryne.utils.data.data_utils import get_shuffle_idx, split_dataset
 from katheryne.utils.hparams import HParams
@@ -87,10 +87,11 @@ def create_reward_dataset(hparams: HParams, data_path: List[Union[str, DatasetPa
     conv_format = hparams.get("conv_format", "openbuddy")

     REWARD_FEATURES = datasets.Features({
-        "prompt": [{
+        "history": [{
             'role': datasets.Value(dtype='string', id=None),
             'content': datasets.Value(dtype='string', id=None)
         }],
+        "prompt": datasets.Value(dtype='string', id=None),
         "chosen": datasets.Value(dtype='string', id=None),
         "rejected": datasets.Value(dtype='string', id=None),
     })
@@ -100,8 +101,8 @@ def create_reward_dataset(hparams: HParams, data_path: List[Union[str, DatasetPa
     for di, d_path in enumerate(data_path_obj):
         print(f"Creating dataset: {d_path}")
         train_dataset, eval_dataset = create_dataset(d_path.path)
-        train_dataset = train_dataset.cast(REWARD_FEATURES)
-        eval_dataset = eval_dataset.cast(REWARD_FEATURES)
+        # train_dataset = train_dataset.cast(REWARD_FEATURES)
+        # eval_dataset = eval_dataset.cast(REWARD_FEATURES)

         if d_path.shuffle:
             train_dataset = train_dataset.shuffle(seed=hparams.get("seed", 43))
@@ -124,15 +125,15 @@ def create_reward_dataset(hparams: HParams, data_path: List[Union[str, DatasetPa
     train_dataset = datasets.concatenate_datasets(train_datasets)
     eval_dataset = datasets.concatenate_datasets(eval_datasets)

-    train_dataset = InstructionDataset(
+    train_dataset = RewardDataset(
        train_dataset,
        tokenizer_path,
        max_seq_len,
        tokenizer.pad_token_id,
        conv_format=conv_format,
        end_of_conversation=hparams.get("end_of_conversation", None)
     )
-    eval_dataset = InstructionDataset(
+    eval_dataset = RewardDataset(
        eval_dataset,
        tokenizer_path,
        max_seq_len,
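The two cast(REWARD_FEATURES) calls are commented out, presumably because the raw columns of the source dataset do not yet match the declared history/prompt/chosen/rejected schema. A hypothetical map that would reshape an rm-static style row into that schema (the to_reward_features helper is illustrative and not part of the repository):

# Hypothetical pre-processing sketch: reshape a raw pairwise row into the schema
# declared by REWARD_FEATURES. Not part of the commit; names are illustrative.
def to_reward_features(example):
    return {
        "history": [],                 # no prior turns in single-turn preference data
        "prompt": example["prompt"],
        "chosen": example["chosen"],
        "rejected": example["rejected"],
    }

# train_dataset = train_dataset.map(to_reward_features,
#                                   remove_columns=train_dataset.column_names)
# train_dataset = train_dataset.cast(REWARD_FEATURES)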
60 changes: 40 additions & 20 deletions katheryne/datasets/reward_dataset.py
@@ -42,30 +42,50 @@ def __getitem__(self, idx):
         sample = self.dataset[idx]

         messages = sample["messages"]
-        prompt, indices = self.get_prompt(messages)
+        chosen = sample["chosen"]
+        rejected = sample["rejected"]
+
+        chosen_messages = messages + {"role": "assistant", "content": chosen}
+        rejected_messages = messages + {"role": "assistant", "content": rejected}
+
+        """Chosen messages"""
+        chosen_prompt, chosen_indices = self.get_prompt(chosen_messages)
         if isinstance(self.end_of_conversation, str):
-            prompt += self.end_of_conversation
+            chosen_prompt += self.end_of_conversation
+
+        chosen_encoded_text = self.tokenize(chosen_prompt, add_special_tokens=True)
+        chosen_input_ids = chosen_encoded_text["input_ids"].squeeze(0)
+        chosen_attention_mask = chosen_encoded_text["attention_mask"].squeeze(0)

-        encoded_text = self.tokenize(prompt, add_special_tokens=True)
+        # if truncation not work
+        if len(chosen_input_ids) > self.max_seq_len:
+            chosen_input_ids = chosen_input_ids[:self.max_seq_len]
+            chosen_input_ids[-1] = self.tokenizer.eos_token_id
+            chosen_attention_mask = chosen_attention_mask[:self.max_seq_len]
+
+        chosen_input_ids, chosen_attention_mask = self.add_end_of_conv(chosen_input_ids, chosen_attention_mask, self.end_of_conversation)

-        input_ids = encoded_text["input_ids"].squeeze(0)
-        attention_mask = encoded_text["attention_mask"].squeeze(0)
+        """Rejected messages"""
+        rejected_prompt, rejected_indices = self.get_prompt(rejected_messages)
+        if isinstance(self.end_of_conversation, str):
+            rejected_prompt += self.end_of_conversation
+
+        rejected_encoded_text = self.tokenize(rejected_prompt, add_special_tokens=True)
+        rejected_input_ids = rejected_encoded_text["input_ids"].squeeze(0)
+        rejected_attention_mask = rejected_encoded_text["attention_mask"].squeeze(0)

-        # if truncation not work
-        if len(input_ids) > self.max_seq_len:
-            input_ids = input_ids[:self.max_seq_len]
-            input_ids[-1] = self.tokenizer.eos_token_id
-            attention_mask = attention_mask[:self.max_seq_len]
-
-        input_ids, attention_mask = self.add_end_of_conv(input_ids, attention_mask, self.end_of_conversation)
-        labels = input_ids.clone()
-        labels = self.mask_label(prompt, labels, indices)
-        # print(len(input_ids), len(labels))
-        # TODO: pad labels with IGNORE_TOKEN_ID
-        # labels[:len(encoded_prompt) + 1] = IGNORE_TOKEN_ID  # no +1 here to offset the bos, since the last prompt token may be a space that merges with the first answer token
+        if len(rejected_input_ids) > self.max_seq_len:
+            rejected_input_ids = rejected_input_ids[:self.max_seq_len]
+            rejected_input_ids[-1] = self.tokenizer.eos_token_id
+            rejected_attention_mask = rejected_attention_mask[:self.max_seq_len]
+
+        rejected_input_ids, rejected_attention_mask = self.add_end_of_conv(rejected_input_ids, rejected_attention_mask, self.end_of_conversation)
+

         return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels
+            "chosen_input_ids": chosen_input_ids,
+            "chosen_attention_mask": chosen_attention_mask,
+            "rejected_input_ids": rejected_input_ids,
+            "rejected_attention_mask": rejected_attention_mask,
         }
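Note that messages + {"role": "assistant", "content": chosen} concatenates a list with a dict, which would raise a TypeError at runtime; the intent is presumably to append a single assistant turn. A sketch of that assumed intent (not the committed code):

# Sketch of the presumably intended construction: append one assistant turn to the
# conversation history before rendering the chosen/rejected prompts.
chosen_messages = messages + [{"role": "assistant", "content": chosen}]
rejected_messages = messages + [{"role": "assistant", "content": rejected}]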
40 changes: 24 additions & 16 deletions katheryne/models/reward_models.py
@@ -34,38 +34,46 @@ def set_input_embeddings(self, value):

     def forward(
         self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
+        chosen_input_ids: Optional[torch.LongTensor] = None,
+        chosen_attention_mask: Optional[torch.Tensor] = None,
+        rejected_input_ids: Optional[torch.LongTensor] = None,
+        rejected_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        transformer_outputs = self.model(
+        if chosen_input_ids is None or rejected_input_ids is None:
+            raise Exception("The chosen_input_ids and rejected_input_ids shall not be None.")
+        if chosen_input_ids is not None:
+            chosen_batch_size = chosen_input_ids.shape[0]
+        else:
+            raise Exception("The chosen_input_ids shall not be None.")
+        if rejected_input_ids is not None:
+            rejected_batch_size = rejected_input_ids.shape[0]
+        else:
+            raise Exception("The rejected_input_ids shall not be None.")
+        if chosen_batch_size != rejected_batch_size:
+            raise Exception("The batch size of chosen sentences should equal to that of rejected sentences.")
+        batch_size = chosen_batch_size
+
+        # chosen_input_ids, rejected_input_ids: [batch, seq]
+        input_ids = torch.cat([chosen_input_ids, rejected_input_ids], dim=1)
+        attention_mask = torch.cat([chosen_attention_mask, rejected_attention_mask], dim=1)
+
+        lm_outputs = self.base_model(
             input_ids,
             attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        hidden_states = transformer_outputs[0]
+        hidden_states = lm_outputs[0]
         logits = self.v_head(hidden_states)

-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
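The forward pass above concatenates the chosen and rejected batches along dim=1, the sequence dimension. Pairwise reward models more commonly stack the two along the batch dimension and split the per-sequence scores afterwards; the sketch below shows that common pattern (an illustration, not the loss implemented in this repository, and a reward_model that returns one scalar per sequence is an assumption):

# Sketch of a common pairwise reward objective (DeepSpeed-Chat style): stack chosen and
# rejected along the batch dimension, score both, and maximise the score margin.
# Illustrative only; not the loss used by katheryne.
import torch
import torch.nn.functional as F

def pairwise_reward_loss(reward_model, chosen_input_ids, chosen_attention_mask,
                         rejected_input_ids, rejected_attention_mask):
    batch_size = chosen_input_ids.shape[0]
    input_ids = torch.cat([chosen_input_ids, rejected_input_ids], dim=0)          # [2*batch, seq]
    attention_mask = torch.cat([chosen_attention_mask, rejected_attention_mask], dim=0)

    scores = reward_model(input_ids, attention_mask)   # assumed: one scalar score per sequence
    chosen_scores, rejected_scores = scores[:batch_size], scores[batch_size:]

    # Bradley-Terry style objective: the chosen response should outscore the rejected one.
    loss = -F.logsigmoid(chosen_scores - rejected_scores).mean()
    return loss, chosen_scores, rejected_scores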
(Diffs for the remaining changed files are not shown.)