-
Notifications
You must be signed in to change notification settings - Fork 11
/
train.py
138 lines (107 loc) · 4.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List
import json
import random
import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from transformers import Trainer, AutoConfig, AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaForCausalLM
from dataloaders import TextDataset
from dataloaders import sft_data_collactor, offline_ppo_data_collactor, weighted_sft_data_collactor
from arguments import CustomTrainingArguments
from trainers import SFTWeightedWithKLTrainer, OfflineWeightedPolicyTrainer
from utils import print_rank_0, read_json_or_jsonl_data, set_special_tokens
from utils import DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN
def get_train_dataset(args):
    """Load every file named in ``args.train_data_path`` into one TextDataset.

    Each path is read with ``read_json_or_jsonl_data`` and the resulting
    records are concatenated in the order the paths were given.
    """
    merged_records = []
    for data_path in args.train_data_path:
        merged_records.extend(read_json_or_jsonl_data(data_path))

    if args.debug_mode:
        # Peek at the first record so a mis-formatted file is caught early.
        print_rank_0(f">>> check loaded data:")
        print_rank_0(f">>> {merged_records[0]}")

    return TextDataset(merged_records)
def train():
    """Entry point: parse CLI args, build data/model/tokenizer, run training.

    Side effects: downloads/loads pretrained weights, runs the HF Trainer
    loop, and writes metrics, trainer state, and the final model to
    ``args.output_dir``.

    Raises:
        ValueError: if ``args.train_method`` is not one of
            "SFT", "SFTwithKL", or "OfflinePO".
    """
    parser = transformers.HfArgumentParser(CustomTrainingArguments)
    args = parser.parse_args_into_dataclasses()[0]
    print_rank_0(args)

    # load data
    # ---------------------------------------------------------------------------------
    train_dataset = get_train_dataset(args)

    # setup model
    # ---------------------------------------------------------------------------------
    print_rank_0(f"Begin loading model from {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )
    # Drop any reference model bundled into a saved checkpoint; one is
    # re-attached below only when the training method actually needs it.
    if hasattr(model, "ref_model"):
        del model.ref_model

    if args.train_method in ["SFTwithKL", "OfflinePO"] \
            and args.ref_model_name_or_path:
        ref_model = AutoModelForCausalLM.from_pretrained(
            args.ref_model_name_or_path,
            trust_remote_code=True,
        )
        if hasattr(ref_model, "ref_model"):
            del ref_model.ref_model
        # The reference model is frozen: it only supplies logits for the
        # KL / policy-ratio terms and must never receive gradient updates.
        for param in ref_model.parameters():
            param.requires_grad = False
        model.ref_model = ref_model

    print_rank_0(model)
    print_rank_0(f"Finished loading model from {args.model_name_or_path}")
    model.is_parallelizable = True
    model.model_parallel = True

    # setup tokenizer
    # ---------------------------------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side=args.padding_side,
        truncation_side=args.truncation_side,
        use_fast=True,
        trust_remote_code=True,
    )
    model, tokenizer = set_special_tokens(model, tokenizer)

    # build trainer
    # ---------------------------------------------------------------------------------
    if args.train_method == "SFT":
        trainer = Trainer(
            model=model,
            tokenizer=tokenizer,
            args=args,
            train_dataset=train_dataset,
            data_collator=lambda x: sft_data_collactor(args, x, tokenizer)
        )
    elif args.train_method == "SFTwithKL":
        trainer = SFTWeightedWithKLTrainer(
            model=model,
            tokenizer=tokenizer,
            args=args,
            train_dataset=train_dataset,
            data_collator=lambda x: offline_ppo_data_collactor(args, x, tokenizer)
        )
    elif args.train_method == "OfflinePO":
        trainer = OfflineWeightedPolicyTrainer(
            model=model,
            tokenizer=tokenizer,
            args=args,
            train_dataset=train_dataset,
            data_collator=lambda x: offline_ppo_data_collactor(args, x, tokenizer)
        )
    else:
        # Fail fast with a clear message; previously an unrecognized method
        # fell through and crashed with a NameError on `trainer` below.
        raise ValueError(
            f"Unknown train_method: {args.train_method!r}; "
            "expected one of 'SFT', 'SFTwithKL', 'OfflinePO'."
        )

    train_result = trainer.train()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Detach the frozen reference model before saving so the checkpoint on
    # disk contains only the trained policy weights.
    if hasattr(trainer.model, "ref_model"):
        del trainer.model.ref_model
    trainer.save_model(output_dir=args.output_dir)
# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    train()