HemanM commited on
Commit
4efc6bf
·
verified ·
1 Parent(s): d9fdd26

Update init_model.py

Browse files
Files changed (1) hide show
  1. init_model.py +11 -32
init_model.py CHANGED
@@ -1,39 +1,18 @@
1
  import torch
2
- from evo_model import EvoTransformerForClassification, EvoTransformerConfig
3
- from transformers import AutoTokenizer
4
- from torch.utils.data import DataLoader, TensorDataset
5
  import os
 
6
 
7
- def retrain_model():
8
- print("🔄 Starting Evo retrain...")
 
 
 
9
 
10
- # Sample retraining data
11
- examples = [
12
- "Goal: House on fire. Option 1: Exit house. Option 2: Stay in house.",
13
- "Goal: Wet floor. Option 1: Walk slowly. Option 2: Run fast.",
14
- "Goal: Loud music. Option 1: Turn it down. Option 2: Ignore it."
15
- ]
16
- labels = [0, 0, 0] # Option 1 is correct
17
-
18
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
19
- model = EvoTransformerForClassification(EvoTransformerConfig())
20
-
21
- inputs = tokenizer(examples, padding=True, truncation=True, return_tensors="pt")
22
- labels_tensor = torch.tensor(labels)
23
-
24
- dataset = TensorDataset(inputs["input_ids"], inputs["attention_mask"], labels_tensor)
25
- dataloader = DataLoader(dataset, batch_size=2)
26
-
27
- model.train()
28
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
29
-
30
- for epoch in range(2):
31
- for input_ids, attention_mask, labels_batch in dataloader:
32
- optimizer.zero_grad()
33
- loss, _ = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels_batch)
34
- loss.backward()
35
- optimizer.step()
36
 
 
37
  os.makedirs("trained_model", exist_ok=True)
38
  model.save_pretrained("trained_model")
39
- print("✅ Evo retrained and saved.")
 
 
1
  import torch
 
 
 
2
  import os
3
+ from evo_model import EvoTransformerConfig, EvoTransformerForClassification
4
 
5
+ def initialize_evo_model():
6
+ print("⚙️ Reinitializing EvoTransformer model...")
7
+
8
+ # Create default config
9
+ config = EvoTransformerConfig()
10
 
11
+ # Create model
12
+ model = EvoTransformerForClassification(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Save model to disk
15
  os.makedirs("trained_model", exist_ok=True)
16
  model.save_pretrained("trained_model")
17
+
18
+ print("✅ EvoTransformer initial model saved to 'trained_model/'")