title | description | icon |
---|---|---|
On-Device Training Quickstart | Learn how to train models on your device using Simplifine. | computer |
This workflow is optimized for training on your own infrastructure.
```python
from simplifine_alpha import train_engine
import wandb
import os
```
WandB logging is controlled through the `report_to` argument in the training arguments object passed to the trainer. It can be fully disabled as follows:
```python
# Disable WandB logging; change this if you'd like to have it enabled.
wandb.init(mode='disabled')
```
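If you would rather keep WandB logging on, here is a minimal sketch of initializing a run first. The project name `simplifine-on-device` and the `WANDB_API_KEY` environment variable are placeholder assumptions, not part of Simplifine itself:

```python
# Assumption: your WandB API key is stored in the WANDB_API_KEY environment
# variable, and 'simplifine-on-device' is just a placeholder project name.
wandb.login(key=os.environ.get('WANDB_API_KEY'))
wandb.init(project='simplifine-on-device')
```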
You can provide the name of a Hugging Face (HF) dataset if you'd like to use a pre-existing dataset. Ensure that you adjust `keys`, `response_template`, and `template` to match your dataset's structure.
```python
# You can provide an HF dataset name.
# Be sure to change the keys, response template, and template accordingly.
template = '''### TITLE: {title}\n ### ABSTRACT: {abstract}\n ###EXPLANATION: {explanation}'''
response_template = '\n ###EXPLANATION:'
keys = ['title', 'abstract', 'explanation']
dataset_name = 'REPLACE_WITH_HF_DATASET_NAME'
```
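To see how the template and response template fit together, here is a small illustration with a made-up row; the field values are purely hypothetical:

```python
# Hypothetical row, only to illustrate how the prompt is assembled.
# The text following response_template is what the model learns to generate.
example_row = {
    'title': 'Example paper title',
    'abstract': 'Example abstract text.',
    'explanation': 'Example explanation the model should learn to produce.'
}
print(template.format(**example_row))
```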
If you don't want to use a Hugging Face dataset, set `from_hf` to `False` and provide a local dataset. The example below shows how to manually create one:
```python
from_hf = True
if True:  # Change this condition to False to keep using the Hugging Face dataset.
    from_hf = False

# A manually created local dataset in column-oriented form.
data = {
    'title': ['title 1', 'title 2', 'title 3'] * 200,
    'abstract': ['abstract 1', 'abstract 2', 'abstract 3'] * 200,
    'explanation': ['explanation 1', 'explanation 2', 'explanation 3'] * 200
}
```
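If your local data lives in a file instead, here is a minimal sketch of reading a CSV into the same column-oriented format. The filename `data.csv` and its column names are assumptions for illustration:

```python
import csv

# Assumption: data.csv has 'title', 'abstract', and 'explanation' columns
# matching the keys defined above.
data = {key: [] for key in keys}
with open('data.csv', newline='') as f:
    for row in csv.DictReader(f):
        for key in keys:
            data[key].append(row[key])
```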
Choose a model for training.
```python
# You can change the model. Bigger models might throw OOM errors.
model_name = 'EleutherAI/pythia-160m'
```
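If you're unsure whether a larger model will fit in memory, one rough check is to load it and count its parameters. This sketch uses the Hugging Face `transformers` library directly, which is an assumption about your environment rather than part of the Simplifine API:

```python
from transformers import AutoModelForCausalLM

# Load the model once and print its parameter count as a rough proxy
# for how much memory training will need.
model = AutoModelForCausalLM.from_pretrained(model_name)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
```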
Finally, use the `train_engine.hf_sft` function to start training on your device. Its parameters let you control the training process, including whether to use mixed-precision training (`fp16`), distributed data parallel (`ddp`), and more:
```python
train_engine.hf_sft(model_name,
                    from_hf=from_hf,
                    dataset_name=dataset_name,
                    keys=keys,
                    data=data,
                    template=template,
                    response_template=response_template,
                    zero=False,
                    ddp=False,
                    gradient_accumulation_steps=4,
                    fp16=True,
                    max_seq_length=2048)
```
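Note that `fp16=True` generally requires a CUDA GPU. A quick check before launching, assuming PyTorch is installed in your environment:

```python
import torch

# Mixed-precision (fp16) training generally requires a CUDA GPU; on a
# CPU-only machine, pass fp16=False to hf_sft instead.
if not torch.cuda.is_available():
    print('No CUDA device found; consider setting fp16=False and using a smaller model.')
```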