Skip to content

Commit

Permalink
concat datasets bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jan 2, 2024
1 parent 3b73d13 commit 750375d
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion hparams/hparams_chat_llama2_13b_lora.json
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
},
{
"path": "../../datasets/HanChat/Summary/",
"sample": 1.0
"sample": 0.1
},
{
"path": "../../datasets/HanChat/TextCorrection/",
Expand Down
4 changes: 2 additions & 2 deletions katheryne/data/loader/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ def create_chat_dataset(hparams: HParams, data_path: List[Union[str, DatasetPath

if isinstance(d_path.sample, int):
sample_size = d_path.sample
train_dataset = train_dataset[:sample_size]
train_dataset = train_dataset.select(list(range(sample_size)))
elif isinstance(d_path.sample, float):
if d_path.sample != 1.0:
sample_size = int(d_path.sample * len(train_dataset))
train_dataset = train_dataset[:sample_size]
train_dataset = train_dataset.select(list(range(sample_size)))
else:
raise TypeError("Invalid sample number of dataset path object, need int or float.")

Expand Down
4 changes: 2 additions & 2 deletions katheryne/data/loader/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ def create_pretrain_dataset(hparams: HParams, data_path: List[Union[str, Dataset

if isinstance(d_path.sample, int):
sample_size = d_path.sample
train_dataset = train_dataset[:sample_size]
train_dataset = train_dataset.select(list(range(sample_size)))
elif isinstance(d_path.sample, float):
if d_path.sample != 1.0:
sample_size = int(d_path.sample * len(train_dataset))
train_dataset = train_dataset[:sample_size]
train_dataset = train_dataset.select(list(range(sample_size)))
else:
raise TypeError("Invalid sample number of dataset path object, need int or float.")

Expand Down
2 changes: 1 addition & 1 deletion katheryne/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def train(create_dataset, lightning_module_class):
model.gradient_checkpointing_enable()

# Save Model
save_hf_format(model, tokenizer, "./lightning_logs/huggingface_format", sub_folder=f"checkpoint-step-0")
# save_hf_format(model, tokenizer, "./lightning_logs/huggingface_format", sub_folder=f"checkpoint-step-0")

# Prepare the data
print("***** Prepare Dataset *****")
Expand Down

0 comments on commit 750375d

Please sign in to comment.