Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
Signed-off-by: ssbuild <[email protected]>
  • Loading branch information
ssbuild committed Oct 11, 2023
1 parent 6e57ff3 commit f8b1c12
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions data_processer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,17 @@ def final(cls,a_ids,b_ids,max_seq_length):
seqlen = len(a_ids)
decoder_seqlen = len(b_ids)

a_ids += [0] * (max_seq_length - len(a_ids))
b_ids += [0] * (max_seq_length - len(b_ids))
labels = copy.deepcopy(a_ids[1:]) + [-100]
pad_len = max_seq_length - seqlen
if pad_len > 0:
a_ids += [0] * pad_len

pad_len = max_seq_length - decoder_seqlen
if pad_len > 0:
b_ids += [0] * pad_len

labels = copy.deepcopy(b_ids[1:]) + [-100]
labels = np.asarray(labels, dtype=np.int64)
labels[decoder_seqlen - 1:] = -100
labels[decoder_seqlen:] = -100

d = {
'input_ids': np.asarray(a_ids, dtype=np.int32),
Expand Down

0 comments on commit f8b1c12

Please sign in to comment.