Commit

lora model loading
jstzwj committed Jul 8, 2024
1 parent ebb1b1e commit dc4f40e
Showing 10 changed files with 295 additions and 236 deletions.
84 changes: 84 additions & 0 deletions katheryne/light_modules/trainer_rlhf.py
@@ -0,0 +1,84 @@

import trl
import torch
import tqdm
from transformers import PreTrainedModel
from katheryne.utils.hparams import HParams


class TrainerRLHF(object):
def __init__(self, hparams: HParams, trainer: trl.trainer.BaseTrainer,
model: PreTrainedModel, tokenizer,
reward_model: PreTrainedModel, reward_tokenizer,
ref_model: PreTrainedModel, ref_tokenizer) -> None:
self.hparams = hparams
self.trainer = trainer

self.model, self.tokenizer = model, tokenizer
self.reward_model, self.reward_tokenizer = reward_model, reward_tokenizer
self.ref_model, self.ref_tokenizer = ref_model, ref_tokenizer

def train(self) -> None:
# Move Reward Model to CUDA
device = self.trainer.accelerator.device
if self.trainer.accelerator.num_processes == 1:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
self.reward_model.to(device)

# output_length_sampler = LengthSampler(hparams.get("output_min_length", 16), hparams.get("output_max_length", 1024))

generation_kwargs = {
"num_beams": 1,
"do_sample": False,
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": 2048,
}

max_epochs = self.hparams.get("max_epochs", 999)
max_steps = self.hparams.get("max_steps", -1)
epoch_steps = len(self.trainer.dataloader)

dataiter = iter(self.trainer.dataloader)

for epoch in range(max_epochs):
            for step in tqdm.tqdm(range(epoch_steps)):
try:
batch = next(dataiter)
except StopIteration:
dataiter = iter(self.trainer.dataloader)
batch = next(dataiter)

# dict_keys(['input_ids', 'attention_mask', 'labels', 'response'])
query_tensor_input_ids = batch["input_ids"]
query_tensors = [query_tensor for query_tensor in query_tensor_input_ids]

response_tensors = self.trainer.generate(query_tensor=query_tensors, batch_size=2, return_prompt=True, **generation_kwargs)

batch["response"] = [self.tokenizer.decode(r.squeeze()) for r in response_tensors]
for i in range(len(query_tensors)):
print(self.tokenizer.decode(query_tensors[i].squeeze()))
print("--------------")
print(self.tokenizer.decode(response_tensors[i].squeeze()))
print("===========")

# Compute reward score
encoded_texts = self.reward_tokenizer(batch["response"],
padding="longest",
truncation=True,
return_tensors="pt",
add_special_tokens=True,
).to(self.reward_model.device)

rewards = self.reward_model.forward(
input_ids=encoded_texts["input_ids"],
attention_mask=encoded_texts["attention_mask"],
)
score_tensor = rewards.logits
scores = [s.item() for s in score_tensor]
# Run PPO step
stats = self.trainer.step(query_tensors, response_tensors, scores)
self.trainer.log_stats(stats, batch, scores)

# Save Checkpoints
# TODO: ....
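
Usage note (not part of the commit): TrainerRLHF only stores its arguments and exposes train(); the models, tokenizers, and the TRL trainer are built elsewhere (see katheryne/stages/rlhf_base.py below). A minimal sketch of driving the class end to end, where load_policy, load_reward, and make_ppo_trainer are hypothetical stand-ins for that setup code and the hparams path is illustrative:

# Sketch only -- helper names and the hparams path below are hypothetical
# placeholders for the construction logic in katheryne/stages/rlhf_base.py.
from katheryne.light_modules.trainer_rlhf import TrainerRLHF
from katheryne.utils.hparams import HParams

hparams = HParams.from_json_file("hparams/ppo_example.json")          # hypothetical path
model, tokenizer = load_policy(hparams)                               # hypothetical helper
reward_model, reward_tokenizer = load_reward(hparams)                 # hypothetical helper
ref_model, _ = load_policy(hparams)                                   # frozen reference copy
ppo_trainer = make_ppo_trainer(hparams, model, ref_model, tokenizer)  # hypothetical helper

rlhf = TrainerRLHF(
    hparams=hparams,
    trainer=ppo_trainer,
    model=model, tokenizer=tokenizer,
    reward_model=reward_model, reward_tokenizer=reward_tokenizer,
    ref_model=ref_model, ref_tokenizer=tokenizer,
)
rlhf.train()  # generate responses, score them with the reward model, run PPO steps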
27 changes: 27 additions & 0 deletions katheryne/stages/dpo.py
@@ -0,0 +1,27 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from katheryne.stages.base import parse_args
from katheryne.stages.rlhf_base import rlhf_train
from katheryne.utils.hparams import HParams

def dpo():
args = parse_args()
hparams = HParams.from_json_file(args.hparams)
train_stage = hparams.get("train_stage", None)
if train_stage is None:
raise Exception("Please specify the train stage in the hparam file.")

if train_stage in ["dpo"]:
from katheryne.data.loader.rlhf import create_rlhf_dataset
from trl import DPOTrainer, DPOConfig
rlhf_train(args, hparams, create_rlhf_dataset, DPOConfig, DPOTrainer)
else:
raise Exception("The train stage is not consistent with the stage in config.")

if __name__ == "__main__":
dpo()
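
Configuration note (illustrative, not part of the commit): every stage entry point reads an hparams JSON via parse_args() and HParams.from_json_file, then dispatches only when the file's "train_stage" matches. A minimal hparams sketch that would route through this DPO entry point is shown below; only "train_stage" is checked in dpo(), the remaining keys are examples of settings read later by rlhf_train, and the output file name is hypothetical. kto.py and ppo.py follow the same pattern with "kto" and "ppo".

# Sketch only: write an illustrative hparams file for the DPO stage.
import json

dpo_hparams = {
    "train_stage": "dpo",                   # must match the check in dpo()
    "max_seq_len": 2048,                    # used when building the padding collator
    "max_epochs": 3,                        # read by the training loop
    "per_device_train_batch_size": 8,       # read when building the TRL config
    "per_device_train_mini_batch_size": 8,  # read when building the TRL config
}

with open("hparams/dpo_example.json", "w") as f:  # hypothetical file name
    json.dump(dpo_hparams, f, indent=2)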
Empty file removed katheryne/stages/kpo.py
27 changes: 27 additions & 0 deletions katheryne/stages/kto.py
@@ -0,0 +1,27 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from katheryne.stages.base import parse_args
from katheryne.stages.rlhf_base import rlhf_train
from katheryne.utils.hparams import HParams

def kto():
args = parse_args()
hparams = HParams.from_json_file(args.hparams)
train_stage = hparams.get("train_stage", None)
if train_stage is None:
raise Exception("Please specify the train stage in the hparam file.")

if train_stage in ["kto"]:
from katheryne.data.loader.rlhf import create_rlhf_dataset
from trl import KTOTrainer, KTOConfig
rlhf_train(args, hparams, create_rlhf_dataset, KTOConfig, KTOTrainer)
else:
raise Exception("The train stage is not consistent with the stage in config.")

if __name__ == "__main__":
kto()
11 changes: 6 additions & 5 deletions katheryne/stages/ppo.py
@@ -5,7 +5,8 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from katheryne.stages.base import train, parse_args
from katheryne.stages.base import parse_args
from katheryne.stages.rlhf_base import rlhf_train
from katheryne.utils.hparams import HParams

def ppo():
@@ -16,11 +17,11 @@ def ppo():
raise Exception("Please specify the train stage in the hparam file.")

if train_stage in ["ppo"]:
from katheryne.light_modules.models.pretrain_model import PretrainLanguageModel
from katheryne.data.loader.pretrain import create_pretrain_dataset
train(args, hparams, create_pretrain_dataset, PretrainLanguageModel)
from katheryne.data.loader.rlhf import create_rlhf_dataset
from trl import PPOTrainer, PPOConfig
rlhf_train(args, hparams, create_rlhf_dataset, PPOConfig, PPOTrainer)
else:
raise Exception("The train stage is not consistent with the stage in config.")

if __name__ == "__main__":
ppo()
115 changes: 27 additions & 88 deletions katheryne/stages/rlhf_base.py
@@ -1,6 +1,6 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
@@ -28,13 +28,14 @@
)

from katheryne.data.collators import DataCollatorWithPadding
from katheryne.light_modules.trainer_rlhf import TrainerRLHF
from katheryne.models.adapters import setup_lora
from katheryne.utils.hparams import HParams
from katheryne.utils.model.model_utils import create_hf_model
from katheryne.utils.model.tokenizer_utils import load_hf_tokenizer
from katheryne.utils.utils import parse_dtype_str

def rlhf_train(args: argparse.Namespace, hparams: HParams, create_dataset):
def rlhf_train(args: argparse.Namespace, hparams: HParams, create_dataset, config_class, trainer_class):
torch.autograd.set_detect_anomaly(True)
master_port = os.environ.get("MASTER_PORT", None)
master_addr = os.environ.get("MASTER_ADDR", None)
@@ -69,12 +70,12 @@ def rlhf_train(args: argparse.Namespace, hparams: HParams, create_dataset):
torch_dtype = parse_dtype_str(torch_dtype_str)
else:
torch_dtype = torch_dtype_str

if torch_dtype == torch.bfloat16 and args.accelerator in ["cpu"]:
raise RuntimeError("Models in bfloat16 cannot run with the accelerator CPU.")
if torch_dtype == torch.float16 and args.accelerator in ["cpu"]:
raise RuntimeError("Models in float16 cannot run with the accelerator CPU.")

# Create Model
model_class_config = hparams.get("model_class", "AutoModelForCausalLM")
if model_class_config == "AutoModelForCausalLM":
@@ -129,7 +130,7 @@ def rlhf_train(args: argparse.Namespace, hparams: HParams, create_dataset):
reward_tokenizer_path = hparams.get("reward_tokenizer_path", reward_model_path)
else:
reward_tokenizer_path = reward_model

# Load Reward Tokenizer
reward_tokenizer = load_hf_tokenizer(reward_tokenizer_path, fast_tokenizer=True, padding_side="left")
reward_tokenizer.pad_token = reward_tokenizer.eos_token
@@ -161,7 +162,7 @@ def rlhf_train(args: argparse.Namespace, hparams: HParams, create_dataset):
padding="longest",
max_length=hparams.max_seq_len
)

# Create Config
config_params = {}
config_params["tracker_kwargs"] = {}
@@ -195,7 +196,7 @@ def rlhf_train(args: argparse.Namespace, hparams: HParams, create_dataset):
raise Exception("Unsupported logger type.")
config_params["project_kwargs"]["logging_dir"] = logger_save_dir
break

config_params["accelerator_kwargs"]["device_placement"] = True
if "fp16" in hparams and hparams.fp16:
print("using fp16")
@@ -227,86 +228,24 @@ def rlhf_train(args: argparse.Namespace, hparams: HParams, create_dataset):
config_params["mini_batch_size"] = hparams.get("per_device_train_mini_batch_size", 128)
config_params["batch_size"] = hparams.get("per_device_train_batch_size", 128)

config = PPOConfig(**config_params)

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=train_dataset, data_collator=collator)

# Move Reward Model to CUDA
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
reward_model.to(device)

# output_length_sampler = LengthSampler(hparams.get("output_min_length", 16), hparams.get("output_max_length", 1024))

generation_kwargs = {
"num_beams": 1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": False,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"max_length": 8192,
"max_new_tokens": 1024,
}

generation_kwargs = {
"num_beams": 1,
"do_sample": False,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"max_new_tokens": 2048,
}

max_epochs = hparams.get("max_epochs", 999)
max_steps = hparams.get("max_steps", -1)
epoch_steps = len(ppo_trainer.dataloader)

dataiter = iter(ppo_trainer.dataloader)

for epoch in range(max_epochs):
for step in enumerate(tqdm.tqdm(range(epoch_steps))):
try:
batch = next(dataiter)
except StopIteration:
dataiter = iter(ppo_trainer.dataloader)
batch = next(dataiter)

# dict_keys(['input_ids', 'attention_mask', 'labels', 'response'])
query_tensor_input_ids = batch["input_ids"]
query_tensors = [query_tensor for query_tensor in query_tensor_input_ids]

response_tensors = ppo_trainer.generate(query_tensor=query_tensors, batch_size=2, return_prompt=True, **generation_kwargs)

batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
for i in range(len(query_tensors)):
print(tokenizer.decode(query_tensors[i].squeeze()))
print("--------------")
print(tokenizer.decode(response_tensors[i].squeeze()))
print("===========")

# Compute reward score
encoded_texts = reward_tokenizer(batch["response"],
padding="longest",
truncation=True,
return_tensors="pt",
add_special_tokens=True,
).to(reward_model.device)

rewards = reward_model.forward(
input_ids=encoded_texts["input_ids"],
attention_mask=encoded_texts["attention_mask"],
)
score_tensor = rewards.logits
scores = [s.item() for s in score_tensor]
# Run PPO step
stats = ppo_trainer.step(query_tensors, response_tensors, scores)
ppo_trainer.log_stats(stats, batch, scores)

# Save Checkpoints
# TODO: ....



config = config_class(**config_params)

trainer = trainer_class(
config,
model,
ref_model,
tokenizer,
dataset=train_dataset,
data_collator=collator,
)

TrainerRLHF(
hparams=hparams,
trainer=trainer,
model=model,
tokenizer=tokenizer,
reward_model=reward_model,
reward_tokenizer=reward_tokenizer,
ref_model=ref_model,
ref_tokenizer=tokenizer,
)
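
Note (illustrative, not part of the commit): for the PPO stage the parameterized construction above is equivalent to the explicit version this commit removes, roughly as sketched below.

# Equivalent explicit form for train_stage == "ppo", restating the removed code;
# config_params, model, ref_model, tokenizer, train_dataset and collator come
# from the surrounding rlhf_train function.
from trl import PPOConfig, PPOTrainer

config = PPOConfig(**config_params)
trainer = PPOTrainer(config, model, ref_model, tokenizer,
                     dataset=train_dataset, data_collator=collator)
# The TrainerRLHF wrapper then owns the generate -> reward -> PPO-step loop
# defined in katheryne/light_modules/trainer_rlhf.py.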
11 changes: 10 additions & 1 deletion katheryne/stages/train.py
@@ -37,7 +37,16 @@ def auto_train_stage():
train(args, hparams, create_reward_dataset, RewardLanguageModel)
elif train_stage in ["ppo"]:
from katheryne.data.loader.rlhf import create_rlhf_dataset
rlhf_train(args, hparams, create_rlhf_dataset)
from trl import PPOTrainer, PPOConfig
rlhf_train(args, hparams, create_rlhf_dataset, PPOConfig, PPOTrainer)
elif train_stage in ["dpo"]:
from katheryne.data.loader.rlhf import create_rlhf_dataset
from trl import DPOTrainer, DPOConfig
rlhf_train(args, hparams, create_rlhf_dataset, DPOConfig, DPOTrainer)
elif train_stage in ["kto"]:
from katheryne.data.loader.rlhf import create_rlhf_dataset
from trl import KTOTrainer, KTOConfig
rlhf_train(args, hparams, create_rlhf_dataset, KTOConfig, KTOTrainer)
else:
raise NotImplementedError("The train stage has not been implemented.")

