Skip to content

Commit

Permalink
chatbot bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jun 21, 2024
1 parent 45aeed3 commit a6258f5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion katheryne/tools/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def load_local_tokenizer(path: str):
adapter_model_json = os.path.join(path, "adapter_config.json")
if os.path.exists(model_json):
model_json_file = json.load(open(model_json))
model_name = model_json_file["_name_or_path"]
if "_name_or_path" in model_json_file:
model_name = model_json_file["_name_or_path"]
else:
model_name = path

if os.path.exists(model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name, fast_tokenizer=True)
else:
Expand Down
4 changes: 2 additions & 2 deletions katheryne/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def train(create_dataset, lightning_module_class):
padding="longest",
max_length=hparams.max_seq_len
)
train_dataloader = DataLoader(train_dataset, collate_fn=collator, sampler=train_sampler, num_workers=4, batch_size=hparams.per_device_train_batch_size)
valid_dataloader = DataLoader(valid_dataset, collate_fn=collator, sampler=valid_sampler, num_workers=4, batch_size=hparams.per_device_eval_batch_size)
train_dataloader = DataLoader(train_dataset, collate_fn=collator, sampler=train_sampler, num_workers=hparams.get("train_num_workers", 4), batch_size=hparams.per_device_train_batch_size)
valid_dataloader = DataLoader(valid_dataset, collate_fn=collator, sampler=valid_sampler, num_workers=hparams.get("valid_num_workers", 4), batch_size=hparams.per_device_eval_batch_size)

model = lightning_module_class(
model,
Expand Down

0 comments on commit a6258f5

Please sign in to comment.