HemanM committed on
Commit
5ff1d2c
·
verified ·
1 Parent(s): 020bae1

Update retrain.py

Browse files
Files changed (1) hide show
  1. retrain.py +49 -0
retrain.py CHANGED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ from transformers import AutoTokenizer
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from evo_model import EvoTransformerV22
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+
9
# Shared Hugging Face tokenizer, created once when this module is imported.
# FeedbackDataset.__getitem__ reads this module-level name to encode each row.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
10
+
11
class FeedbackDataset(Dataset):
    """Dataset of (input_ids, label) pairs read from a feedback CSV.

    The CSV must provide ``query``, ``context`` and ``label`` columns. Rows
    missing any of those three values are discarded; other columns are
    ignored and do not affect which rows are kept.
    """

    # Columns the training loop actually consumes.
    _REQUIRED_COLUMNS = ["query", "context", "label"]

    def __init__(self, csv_file):
        """Load ``csv_file`` and keep only rows usable for training.

        Raises:
            KeyError: if one of the required columns is absent from the CSV.
        """
        frame = pd.read_csv(csv_file)
        # Restrict the NaN filter to the columns we read; a blank value in an
        # unrelated column must not silently throw away a training example.
        self.data = frame.dropna(subset=self._REQUIRED_COLUMNS)

    def __len__(self):
        """Return the number of usable feedback rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)`` for the row at position ``idx``.

        ``input_ids`` is the tokenized "query context" text, padded or
        truncated to 128 tokens; ``label`` is a float tensor of shape ``(1,)``.
        """
        row = self.data.iloc[idx]
        # Format instead of ``+``: pandas may infer a numeric dtype for a
        # column whose cells all look like numbers, and ``int + str`` raises.
        combined = f"{row['query']} {row['context']}"
        enc = tokenizer(
            combined,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        # One target value per example, matching a single-logit output.
        label = torch.tensor(float(row['label'])).unsqueeze(0)
        # The tokenizer returns a (1, 128) batch; drop the batch dimension so
        # the DataLoader can stack examples itself.
        return enc['input_ids'].squeeze(0), label
24
+
25
def fine_tune_on_feedback(model_path="trained_model_evo_hellaswag.pt", feedback_file="feedback_log.csv"):
    """Fine-tune the saved EvoTransformerV22 weights on logged feedback.

    Loads the state dict stored at ``model_path``, runs two epochs of Adam
    (lr=1e-5) with ``BCEWithLogitsLoss`` over the rows of ``feedback_file``
    on the CPU, then overwrites ``model_path`` with the updated weights.
    If the feedback file has no usable rows, the checkpoint is left untouched.

    Args:
        model_path: Path of the state-dict checkpoint to read and overwrite.
        feedback_file: CSV path passed to ``FeedbackDataset``.

    Returns:
        None.
    """
    device = torch.device("cpu")

    model = EvoTransformerV22()
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host,
    # which is where this function always trains.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.train()

    dataset = FeedbackDataset(feedback_file)
    if len(dataset) == 0:
        # Nothing to learn from: skip training (a shuffled DataLoader over an
        # empty dataset can raise) and do not rewrite the checkpoint.
        print("No usable feedback rows found; Evo left unchanged.")
        return

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    loss_fn = nn.BCEWithLogitsLoss()

    for epoch in range(2):  # Light touch-up
        total_loss = 0.0
        for input_ids, labels in dataloader:
            input_ids = input_ids.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # NOTE(review): assumes the model emits one logit per example so
            # the flattened outputs line up with the flattened labels.
            outputs = model(input_ids)
            loss = loss_fn(outputs.view(-1), labels.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Reported value is the sum of per-batch losses for the epoch.
        print(f"Epoch {epoch + 1} Loss: {total_loss:.4f}")

    torch.save(model.state_dict(), model_path)
    print("✅ Evo updated from feedback.")