Create train.py
train.py
ADDED
@@ -0,0 +1,54 @@
# train.py — EvoDecoderModel training loop
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from evo_model import EvoDecoderModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BERT's tokenizer is used only for its vocabulary and padding token
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_len=512):
        self.tokenizer = tokenizer
        self.inputs = [
            tokenizer.encode(t, truncation=True, max_length=max_len, padding='max_length')
            for t in texts
        ]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:]
        x = torch.tensor(self.inputs[idx][:-1])
        y = torch.tensor(self.inputs[idx][1:])
        return x, y

# Example data (replace with your own)
texts = [
    "User: How are you?\nAssistant: I'm doing well, thank you.",
    "User: What is AI?\nAssistant: AI stands for artificial intelligence.",
    # Add more...
]
dataset = TextDataset(texts, tokenizer)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
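
# Optional sanity check: each batch should be a pair of (batch_size, max_len - 1)
# tensors, i.e. torch.Size([2, 511]) with the defaults above.
# xb, yb = next(iter(loader))
# print(xb.shape, yb.shape)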

# Initialize model
model = EvoDecoderModel(vocab_size=tokenizer.vocab_size, d_model=512).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
# Ignore padded positions so [PAD] targets don't dominate the loss
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Training loop
epochs = 5
for epoch in range(epochs):
    total_loss = 0
    model.train()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        # Flatten (batch, seq, vocab) logits against (batch, seq) targets
        loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss/len(loader):.4f}")

torch.save(model.state_dict(), "evo_decoder_model.pt")
print("✅ Model saved to evo_decoder_model.pt")
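
For a quick check after training, here is a minimal greedy-decoding sketch. It is not part of this commit; it assumes, as the training loop above does, that EvoDecoderModel's forward takes a (batch, seq) tensor of token ids and returns (batch, seq, vocab) logits, and it treats BERT's [SEP] token as an end-of-text marker, which is a further assumption.

# generate.py — greedy decoding sketch (hypothetical usage of the saved checkpoint)
import torch
from transformers import AutoTokenizer
from evo_model import EvoDecoderModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoDecoderModel(vocab_size=tokenizer.vocab_size, d_model=512).to(device)
model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
model.eval()

prompt = "User: How are you?\nAssistant:"
# Mirror the training format: [CLS] then text, with no trailing [SEP]
prompt_ids = [tokenizer.cls_token_id] + tokenizer.encode(prompt, add_special_tokens=False)
ids = torch.tensor([prompt_ids], device=device)
with torch.no_grad():
    for _ in range(40):  # generate up to 40 new tokens
        logits = model(ids)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        ids = torch.cat([ids, next_id], dim=1)
        if next_id.item() == tokenizer.sep_token_id:  # assumed end-of-text
            break
print(tokenizer.decode(ids[0], skip_special_tokens=True))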