# roll-ai's picture
# Upload 333 files
# e8bdafd verified
# raw
# history blame contribute delete
# 432 Bytes
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from finetune.models.utils import get_model_cls
from finetune.schemas import Args
import torch
# Opt in to TensorFloat-32 for CUDA float32 matrix multiplies. This is set at
# import time, before `main` runs. Per the PyTorch docs, it trades some matmul
# precision for speed on GPUs with TF32 support.
torch.backends.cuda.matmul.allow_tf32 = True


def main():
    """Launch a fine-tuning run.

    Builds an ``Args`` object via ``Args.parse_args()`` (presumably from the
    command line — confirm in ``finetune.schemas``). It then looks up the
    trainer class for ``(model_name, training_type)`` with ``get_model_cls``,
    constructs that trainer with the parsed arguments, and calls its ``fit``
    method.
    """
    options = Args.parse_args()
    trainer_factory = get_model_cls(options.model_name, options.training_type)
    # Construct the trainer and start training in one step; the instance is
    # not needed afterwards.
    trainer_factory(options).fit()


if __name__ == "__main__":
    main()