Estherrr777 committed (verified)
Commit 8d218d1 · 1 parent: bdbb3ae

Create app/train.py

Files changed (1): backend/app/train.py (+79 -0)
backend/app/train.py ADDED
@@ -0,0 +1,79 @@
+ import json
+ import os
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling,
+ )
+ import torch
+ from datasets import Dataset
+
+ # -------- Settings --------
+ MODEL_NAME = "google/gemma-1.1-2b-it"
+ DATA_PATH = "./backend/data/pregnancy_dataset.json"
+ SAVE_PATH = "./backend/app/checkpoints"
+
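+ # NOTE: the Gemma weights are gated on the Hugging Face Hub, so the
+ # from_pretrained() calls below assume the model license has been accepted
+ # and the environment is authenticated (e.g. via `huggingface-cli login`
+ # or an HF_TOKEN environment variable).
+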
+ # -------- Load Dataset --------
+ def load_dataset():
+     with open(DATA_PATH, "r") as f:
+         data = json.load(f)
+     dataset = Dataset.from_list(data)
+     return dataset
+
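+ # Each record is assumed to carry the two keys read by tokenize_function
+ # below, e.g. (illustrative values only):
+ #   {"prompt": "Is light exercise safe during pregnancy?",
+ #    "completion": "For most pregnancies, moderate exercise is considered safe..."}
+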
+ # -------- Tokenization --------
+ def tokenize_function(example, tokenizer):
+     # Format each pair with Gemma's chat markers: an explicit user turn
+     # followed by a model turn.
+     text = (
+         f"<start_of_turn>user\n{example['prompt']}<end_of_turn>\n"
+         f"<start_of_turn>model\n{example['completion']}<end_of_turn>"
+     )
+     return tokenizer(
+         text,
+         truncation=True,
+         padding="max_length",
+         max_length=256,
+     )
+
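+ # Worked example: {"prompt": "P", "completion": "C"} is encoded from the
+ # single string "<start_of_turn>user\nP<end_of_turn>\n<start_of_turn>model\nC<end_of_turn>",
+ # returning a dict with "input_ids" and "attention_mask", each padded or
+ # truncated to exactly 256 tokens.
+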
+ # -------- Main Training Function --------
+ def train():
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+     # Optional: freeze the base model and train only the output head.
+     # (Gemma ties lm_head to the input embeddings, so unfreezing the head
+     # also unfreezes those shared weights.)
+     for param in model.base_model.parameters():
+         param.requires_grad = False
+     for param in model.lm_head.parameters():
+         param.requires_grad = True
+
+     # Load and tokenize; drop the raw string columns so the batches passed
+     # to the collator contain only token tensors (required because
+     # remove_unused_columns=False is set below).
+     raw_dataset = load_dataset()
+     tokenized_dataset = raw_dataset.map(
+         lambda x: tokenize_function(x, tokenizer),
+         remove_columns=raw_dataset.column_names,
+     )
+
+     # Causal-LM collator: with mlm=False it copies input_ids into labels,
+     # masking pad positions with -100.
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+     # Training arguments
+     training_args = TrainingArguments(
+         output_dir=SAVE_PATH,
+         num_train_epochs=3,
+         per_device_train_batch_size=2,
+         save_steps=50,
+         logging_steps=10,
+         save_total_limit=1,
+         remove_unused_columns=False,
+         report_to="none",
+     )
+
+     # Trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_dataset,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+     )
+
+     trainer.train()
+     trainer.save_model(SAVE_PATH)
+     tokenizer.save_pretrained(SAVE_PATH)
+     print("✅ Fine-tuned model saved!")
+
+ if __name__ == "__main__":
+     train()
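
A minimal inference sketch once train() has saved the checkpoint (the question is hypothetical, and the prompt must reuse the same turn format the script trains on):

    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("./backend/app/checkpoints")
    model = AutoModelForCausalLM.from_pretrained("./backend/app/checkpoints")

    # End with an open model turn so generation continues as the assistant.
    prompt = (
        "<start_of_turn>user\nWhat foods are rich in folate?<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=128)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))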