import argparse
import copy

import torch
import datasets as hfds
import transformers
from tqdm.auto import tqdm
import wandb
args = argparse.Namespace()
args.seed = 42
args.run_name = "run_0"
args.model_name = "meta-llama/Llama-2-7b-hf"
args.dataset = "teknium/GPT4-LLM-Cleaned"
args.eval_size = 1000
args.dtype = torch.bfloat16
# args.dtype = torch.float32
args.model_max_length = 512
args.train_batch_size = 8
args.gradient_accumulation_steps = 16
args.eval_batch_size = 16
args.eval_steps = 50
args.lr = 2e-5
args.num_epochs = 5
args.num_workers = 4
args.device = "cuda"
# args.device = "cpu"

use_autocast = args.dtype != torch.float32
print(f"use_autocast: {use_autocast}")
transformers.set_seed(args.seed)
# ======================== setup wandb ======================
run = wandb.init(
    project="minimal-finetuning",
    name=args.run_name,
    config=vars(args)
)
# ======================== setup dataset ======================
IGNORE_INDEX = -100
PROMPT_FORMAT = (
    """Below is an instruction that describes a task. """
    """Write a response that appropriately completes the request.\n\n"""
    """### Instruction:\n{instruction}\n\n### Response:"""
)
PROMPT_WITH_INPUT_FORMAT = (
    """Below is an instruction that describes a task, paired with an input that provides further context. """
    """Write a response that appropriately completes the request.\n\n"""
    """### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"""
)
TARGET_FORMAT = """{output}{eos_token}"""
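# Each training example is rendered as prompt + output + eos_token; examples with a
# non-empty 'input' field use the prompt variant that includes an "### Input:" section.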
def _preprocess_train_example(example, eos_token):
    if example['input']:
        prompt_text = PROMPT_WITH_INPUT_FORMAT.format_map(example)
    else:
        prompt_text = PROMPT_FORMAT.format_map(example)
    input_text = prompt_text + TARGET_FORMAT.format_map({**example, "eos_token": eos_token})
    return {"prompt_text": prompt_text, "input_text": input_text}
ds = hfds.load_dataset(args.dataset)['train']
ds = ds.train_test_split(args.eval_size, seed=args.seed)
ds = hfds.DatasetDict({'train': ds['train'], 'eval': ds['test']})
print(ds)
tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name, model_max_length=args.model_max_length, use_fast=False)
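# The Llama tokenizer ships without a pad token, so padding falls back to the eos token.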
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
ds = ds.map(_preprocess_train_example, batched=False, desc="preprocessing", num_proc=args.num_workers, fn_kwargs={"eos_token": tokenizer.eos_token})
# ======================== setup dataloaders ======================
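# collate_fn tokenizes the prompt and the full text separately; label positions covering the
# prompt or padding are set to IGNORE_INDEX so the loss is computed only on response tokens.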
def collate_fn(examples):
    prompt_text_enc = tokenizer(
        [example["prompt_text"] for example in examples],
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
    input_text_enc = tokenizer(
        [example["input_text"] for example in examples],
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
    labels = copy.deepcopy(input_text_enc["input_ids"])
    for i in range(len(examples)):
        num_prompt_tokens = prompt_text_enc["input_ids"][i].ne(tokenizer.pad_token_id).sum()
        labels[i][:num_prompt_tokens] = IGNORE_INDEX  # ignore all tokens in the prompt
        labels[i][input_text_enc["attention_mask"][i] == 0] = IGNORE_INDEX  # ignore all pad tokens
    return {**input_text_enc, "labels": labels}
dl_train = torch.utils.data.DataLoader(ds['train'], batch_size=args.train_batch_size, collate_fn=collate_fn, num_workers=args.num_workers, shuffle=True, drop_last=True)
dl_eval = torch.utils.data.DataLoader(ds['eval'], batch_size=args.eval_batch_size, collate_fn=collate_fn, num_workers=args.num_workers)
# ======================== setup model ======================
model = transformers.LlamaForCausalLM.from_pretrained(
    args.model_name,
    torch_dtype=args.dtype,
    device_map='auto'
)
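# With bf16 autocast, keep the norm and embedding modules in float32 for numerical stability.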
if use_autocast:
    for name, module in model.named_modules():
        if ('norm' in name) or ('embed' in name):
            # print(f'using float32: {name}')
            module.to(torch.float32)
# ======================== setup eval loop ======================
def eval():
    with torch.inference_mode():
        eval_losses = []
        for batch in tqdm(dl_eval, desc="eval_step", leave=False, position=1):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=use_autocast, dtype=args.dtype):
                output = model(**batch)
                loss = output.loss
            eval_losses.append(loss.item())
    eval_loss = torch.tensor(eval_losses).mean().item()
    return eval_loss
# ======================== run train loop ======================
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
device = args.device
total_train_steps = (len(dl_train) * args.num_epochs) // args.gradient_accumulation_steps
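# Each optimizer step consumes gradient_accumulation_steps micro-batches, so the effective
# batch size is train_batch_size * gradient_accumulation_steps (8 * 16 = 128 with the defaults above).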
step = 0
train_losses = {}
eval_losses = {}
loss_accum = 0.0

# do eval before training
eval_loss = eval()
eval_losses[step] = eval_loss
wandb.log({'eval/loss': eval_loss}, step=step)
# do training
with tqdm(total=total_train_steps, desc="steps", position=0) as pbar:
    for epoch in range(args.num_epochs):
        for i, batch in enumerate(dl_train):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=use_autocast, dtype=args.dtype):
                output = model(**batch)
                loss = output.loss
            loss = loss / args.gradient_accumulation_steps
            loss.backward()
            loss_accum += loss.item()
            if (i + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                step += 1
                pbar.update(1)
                pbar.write(f"step: {step:05d}\ttrain_loss: {loss_accum}")
                wandb.log({'train/loss': loss_accum}, step=step)
                train_losses[step] = loss_accum
                loss_accum = 0.0
                if step % args.eval_steps == 0:
                    eval_loss = eval()
                    eval_losses[step] = eval_loss
                    pbar.write(f"step: {step:05d}\teval_loss: {eval_loss}")
                    wandb.log({'eval/loss': eval_loss}, step=step)