HemanM committed on
Commit
6a5863f
·
verified ·
1 Parent(s): 738a56e

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +54 -0
train.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py — EvoDecoderModel training loop
2
+ import torch
3
+ from torch import nn, optim
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from transformers import AutoTokenizer
6
+ from evo_model import EvoDecoderModel
7
+
8
# Pick the compute device (GPU when present) and load the tokenizer used
# to turn raw text into token ids.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
10
+
11
class TextDataset(Dataset):
    """Next-token-prediction dataset.

    Each item is an ``(input, target)`` tensor pair where the target is the
    token sequence shifted left by one position, so the model learns to
    predict token ``t+1`` from tokens ``<= t``.
    """

    def __init__(self, texts, tokenizer, max_len=512):
        self.tokenizer = tokenizer
        # Tokenize eagerly: every sequence is truncated/padded to exactly
        # max_len ids, so all items share one length.
        self.inputs = [
            tokenizer.encode(
                text, truncation=True, max_length=max_len, padding='max_length'
            )
            for text in texts
        ]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        ids = self.inputs[idx]
        # Shift-by-one split: inputs drop the last id, targets drop the first.
        return torch.tensor(ids[:-1]), torch.tensor(ids[1:])
23
+
24
# Example data (replace with your own)
texts = [
    "User: How are you?\nAssistant: I'm doing well, thank you.",
    "User: What is AI?\nAssistant: AI stands for artificial intelligence.",
    # Add more...
]

# Wrap the corpus in the Dataset and batch it; shuffle each epoch so the
# model doesn't see the examples in a fixed order.
dataset = TextDataset(texts, tokenizer)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
32
+
33
# Initialize model
model = EvoDecoderModel(vocab_size=tokenizer.vocab_size, d_model=512).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
# FIX: TextDataset pads every sequence to max_len, so without ignore_index
# the loss (and its gradients) would be dominated by PAD-token targets.
# Skipping pad positions makes the loss reflect only real tokens.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
37
+
38
# Training loop: standard teacher-forced language-model training.
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        preds = model(inputs)
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss.
        loss = criterion(preds.view(-1, preds.size(-1)), targets.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss/len(loader):.4f}")
52
+
53
# Persist only the learned parameters (state dict), not the whole module.
checkpoint = "evo_decoder_model.pt"
torch.save(model.state_dict(), checkpoint)
print("✅ Model saved to evo_decoder_model.pt")