Skip to content

Commit

Permalink
more lora types
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 26, 2024
1 parent dc4f40e commit 4c9cbe6
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 4 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@
Easy Language Model Trainer

## Model Support
* LLaMa, LLaMa2, LLaMa3, GLM, Bloom, OPT, GPT2, GPT Neo, GPT Big Code, Qwen, Baichuan and so on.
* LLaMa, LLaMa2, LLaMa3, GLM, Bloom, OPT, GPT2, GPT Neo, GPT Big Code, Qwen, Baichuan and so on.


## TODO List

* SFT Packing
* DPO
* PPO
8 changes: 7 additions & 1 deletion katheryne/models/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from typing import List, Optional
from typing import List, Literal, Optional, Union
from transformers import PreTrainedModel
from peft import LoftQConfig, LoraConfig, get_peft_model, PeftModel

Expand All @@ -23,6 +23,9 @@ def setup_lora(
fan_in_fan_out: bool=False,
bias: str="none",
loftq_config: dict=None,
use_rslora: bool=False,
modules_to_save: Optional[List[str]]=None,
init_lora_weights: Union[bool, Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"]]=True,
use_dora: bool=False,
task_type: str="CAUSAL_LM"
) -> PeftModel:
Expand All @@ -43,6 +46,9 @@ def setup_lora(
fan_in_fan_out=fan_in_fan_out,
bias=bias,
loftq_config=loftq_config,
use_rslora=use_rslora,
modules_to_save=modules_to_save,
init_lora_weights=init_lora_weights,
use_dora=use_dora,
)
model = get_peft_model(base_model, lora_config)
Expand Down
5 changes: 4 additions & 1 deletion katheryne/stages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ def train(args: argparse.Namespace, hparams: HParams, create_dataset, lightning_
fan_in_fan_out=hparams.lora.get("fan_in_fan_out", False),
bias=hparams.lora.get("bias", 'none'),
loftq_config=hparams.lora.get("loftq", None),
use_rslora=hparams.lora.get("use_rslora", False),
modules_to_save=hparams.lora.get("modules_to_save", None),
init_lora_weights=hparams.lora.get("init_lora_weights", True),
use_dora=hparams.lora.get("use_dora", False),
task_type=hparams.lora.get("task_type", "CAUSAL_LM")
task_type=hparams.lora.get("task_type", "CAUSAL_LM"),
)
if hparams.get("gradient_checkpointing", False):
model.enable_input_require_grads()
Expand Down
2 changes: 1 addition & 1 deletion katheryne/tools/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_local_model(path: str):
# base_model.load_adapter(path)
# base_model.enable_adapters()
# model = base_model
model = PeftModelForCausalLM.from_pretrained(base_model, path, is_trainable=True)
model = PeftModelForCausalLM.from_pretrained(base_model, path, is_trainable=False)
model.eval()
else:
model_config = AutoConfig.from_pretrained(path, trust_remote_code=True)
Expand Down

0 comments on commit 4c9cbe6

Please sign in to comment.