Commit

update reward model
jstzwj committed Jun 27, 2024
1 parent 4a66ef6 commit 13a3d90
Showing 3 changed files with 90 additions and 61 deletions.
katheryne/light_modules/models/reward_model.py (14 changes: 9 additions & 5 deletions)
@@ -58,13 +58,17 @@ def forward(self, tokens: Dict[str, torch.Tensor]):
         return lm_output

     def training_step(self, batch, batch_idx: int):
-        input_ids, input_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
+        chosen_input_ids = batch["chosen_input_ids"]
+        chosen_attention_mask = batch["chosen_attention_mask"]
+        rejected_input_ids = batch["rejected_input_ids"]
+        rejected_attention_mask = batch["rejected_attention_mask"]

-        batch_size = input_ids.shape[0]
+        batch_size = chosen_input_ids.shape[0]
         source_tokens = {
-            'input_ids': input_ids,
-            'attention_mask': input_mask,
-            'labels': labels,
+            'chosen_input_ids': chosen_input_ids,
+            'chosen_attention_mask': chosen_attention_mask,
+            'rejected_input_ids': rejected_input_ids,
+            'rejected_attention_mask': rejected_attention_mask,
         }

         lm_output = self.forward(tokens=source_tokens)
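
For context on how these keys might be produced upstream, here is a minimal sketch (not part of this commit) of a collate function yielding the four chosen/rejected tensors consumed by training_step; the tokenizer, max_length, and the raw sample field names "chosen"/"rejected" are assumptions.

# Hypothetical collate_fn producing the batch keys used in training_step above.
from typing import Dict, List

import torch
from transformers import PreTrainedTokenizerBase


def collate_preference_batch(samples: List[Dict[str, str]],
                             tokenizer: PreTrainedTokenizerBase,
                             max_length: int = 512) -> Dict[str, torch.Tensor]:
    # Tokenize the preferred and rejected responses separately to the same length.
    chosen = tokenizer([s["chosen"] for s in samples], padding="max_length",
                       truncation=True, max_length=max_length, return_tensors="pt")
    rejected = tokenizer([s["rejected"] for s in samples], padding="max_length",
                         truncation=True, max_length=max_length, return_tensors="pt")
    return {
        "chosen_input_ids": chosen["input_ids"],
        "chosen_attention_mask": chosen["attention_mask"],
        "rejected_input_ids": rejected["input_ids"],
        "rejected_attention_mask": rejected["attention_mask"],
    }
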
katheryne/models/reward_models.py (127 changes: 75 additions & 52 deletions)
@@ -12,6 +12,19 @@
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import SequenceClassifierOutputWithPast

+from dataclasses import dataclass
+from transformers.utils import ModelOutput
+
+@dataclass
+class RewardOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    chosen_mean_scores: Optional[torch.FloatTensor] = None
+    rejected_mean_scores: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 class KatheryneForRewardModel(nn.Module):
     def __init__(self, base_model: PreTrainedModel, pad_token_id: int, num_padding_at_beginning=0, compute_fp32_loss=False):
         super().__init__()
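
The new RewardOutput carries the pairwise scores alongside the loss. A hedged sketch of how a caller could consume it, for example to compute preference accuracy during evaluation; the model and dataloader are placeholders, and it assumes the module's forward returns the RewardOutput unchanged.

# Hypothetical evaluation helper using the RewardOutput fields introduced above.
import torch


@torch.no_grad()
def preference_accuracy(model, dataloader) -> float:
    correct, total = 0, 0
    for batch in dataloader:
        out = model(tokens=batch)
        # chosen_mean_scores / rejected_mean_scores hold one scalar score per pair.
        correct += (out.chosen_mean_scores > out.rejected_mean_scores).sum().item()
        total += out.chosen_mean_scores.numel()
    return correct / max(total, 1)
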
@@ -102,9 +115,10 @@ def forward(
             rejected_reward = rejected_rewards[i]

             c_inds = (chosen_id == self.pad_token_id).nonzero()
-            c_ind = c_inds[self.num_padding_at_beginning].item() if len(
-                c_inds
-            ) > self.num_padding_at_beginning else seq_len  # OPT model pads the first token, so we need to use the second padding token as the end of the sequence
+            if len(c_inds) > self.num_padding_at_beginning:
+                c_ind = c_inds[self.num_padding_at_beginning].item()
+            else:
+                c_ind = seq_len  # OPT model pads the first token, so we need to use the second padding token as the end of the sequence
             check_divergence = (chosen_id != rejected_id).nonzero()

             if len(check_divergence) == 0:
@@ -114,8 +128,10 @@
             else:
                 # Check if there is any padding otherwise take length of sequence
                 r_inds = (rejected_id == self.pad_token_id).nonzero()
-                r_ind = r_inds[self.num_padding_at_beginning].item(
-                ) if len(r_inds) > self.num_padding_at_beginning else seq_len
+                if len(r_inds) > self.num_padding_at_beginning:
+                    r_ind = r_inds[self.num_padding_at_beginning].item()
+                else:
+                    r_ind = seq_len
                 end_ind = max(c_ind, r_ind)
                 divergence_ind = check_divergence[0]
                 assert divergence_ind > 0
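
The pairwise loss itself falls outside this hunk. In DeepSpeed-Chat-style reward models it is typically the negative log-sigmoid of the chosen-minus-rejected token rewards over the span between divergence_ind and end_ind; a hedged sketch of that computation, assuming per-sequence 1-D reward tensors.

# Sketch of the per-pair ranking loss that usually follows these indices
# (not shown in this diff; the exact formulation in this repo may differ).
import torch
import torch.nn.functional as F


def pairwise_ranking_loss(chosen_reward: torch.Tensor,
                          rejected_reward: torch.Tensor,
                          divergence_ind: int,
                          end_ind: int) -> torch.Tensor:
    # Compare token-level rewards only where the two sequences differ.
    c_truncated = chosen_reward[divergence_ind:end_ind]
    r_truncated = rejected_reward[divergence_ind:end_ind]
    # Maximize the margin between chosen and rejected rewards.
    return -F.logsigmoid(c_truncated - r_truncated).mean()
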
@@ -133,52 +149,59 @@ def forward(
         chosen_mean_scores = torch.stack(chosen_mean_scores)
         rejected_mean_scores = torch.stack(rejected_mean_scores)

-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
-        loss = None
-        if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = nn.MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = nn.CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = nn.BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-        if not return_dict:
-            output = (pooled_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutputWithPast(
+        return RewardOutput(
             loss=loss,
-            logits=pooled_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
+            logits=rewards,
+            chosen_mean_scores=chosen_mean_scores,
+            rejected_mean_scores=rejected_mean_scores,
+            hidden_states=hidden_states,
         )
+
+    def forward_value(self,
+                      input_ids=None,
+                      attention_mask=None,
+                      past_key_values=None,
+                      position_ids=None,
+                      head_mask=None,
+                      inputs_embeds=None,
+                      return_value_only=False,
+                      prompt_length=0,
+                      use_cache=False):
+
+        if self.config.model_type == "llama":
+            kwargs = dict()
+        else:
+            kwargs = dict(head_mask=head_mask)
+
+        transformer_outputs = self.base_model(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            **kwargs)
+        hidden_states = transformer_outputs[0]
+        values = self.v_head(hidden_states).squeeze(-1)
+        if return_value_only:
+            return values
+        else:
+            # [0 0 0 0 prompt, answer, 0 0 0 0 ] for step 3, we have padding at the beginning
+            # [prompt, answer, 0, 0, 0, 0] this is normal
+            assert prompt_length > 1, "prompt_length must be greater than 1 to help select the end score"
+            bs = values.size(0)
+            seq_len = input_ids.shape[1]
+            chosen_end_scores = [
+            ]  # we use this name for consistency with the original forward function
+            for i in range(bs):
+                input_id = input_ids[i]
+                value = values[i]
+
+                c_inds = (input_id[prompt_length:] == self.pad_token_id).nonzero()
+                # here we only use the answer part of the sequence so we do not need to care about the padding at the beginning
+                c_ind = c_inds[0].item() + prompt_length if len(
+                    c_inds) > 0 else seq_len
+                chosen_end_scores.append(value[c_ind - 1])
+            return {
+                "values": values,
+                "chosen_end_scores": torch.stack(chosen_end_scores),
+            }
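
A hedged usage sketch for the new forward_value at inference time; the tokenizer, prompt/response concatenation, and device handling are assumptions and should match however the reward model was trained.

# Hypothetical scoring helper built on forward_value.
import torch


@torch.no_grad()
def score_response(reward_model, tokenizer, prompt: str, response: str,
                   device: str = "cuda") -> float:
    prompt_len = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]
    full = tokenizer(prompt + response, return_tensors="pt").to(device)
    out = reward_model.forward_value(
        input_ids=full["input_ids"],
        attention_mask=full["attention_mask"],
        prompt_length=prompt_len,
    )
    # chosen_end_scores holds the value head output at the last non-pad answer token.
    return out["chosen_end_scores"].item()
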
katheryne/tools/chatbot.py (10 changes: 6 additions & 4 deletions)
@@ -1,7 +1,9 @@
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
+# coding=utf-8
+# Copyright 2024 XiaHan
+#
+# Use of this source code is governed by an MIT-style
+# license that can be found in the LICENSE file or at
+# https://opensource.org/licenses/MIT.

 import torch
 import argparse
