HemanM committed
Commit ef5a88b · verified · 1 Parent(s): cb31db0

Update watchdog.py

Files changed (1)
  1. watchdog.py +47 -25
watchdog.py CHANGED
@@ -5,6 +5,9 @@ from firebase_admin import credentials, firestore
 from evo_model import EvoTransformerForClassification, EvoTransformerConfig
 from transformers import BertTokenizer
 
+from init_model import load_model
+from dashboard import evolution_accuracy_plot
+
 # Initialize Firebase if not already initialized
 if not firebase_admin._apps:
     cred = credentials.Certificate("firebase_key.json")
@@ -44,35 +47,54 @@ def fetch_training_data(tokenizer):
         torch.tensor(labels, dtype=torch.long)
     )
 
-def retrain_and_save():
-    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-    input_ids, attention_masks, labels = fetch_training_data(tokenizer)
-
-    if input_ids is None or len(input_ids) < 2:
-        print("⚠️ Not enough training data.")
-        return
-
-    config = EvoTransformerConfig()
-    model = EvoTransformerForClassification(config)
-    model.train()
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
-    loss_fn = torch.nn.CrossEntropyLoss()
-
-    for epoch in range(3):
-        optimizer.zero_grad()
-        outputs = model(input_ids, attention_mask=attention_masks)
-        loss = loss_fn(outputs, labels)
-        loss.backward()
-        optimizer.step()
-        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")
-
-    os.makedirs("trained_model", exist_ok=True)
-    model.save_pretrained("trained_model")
-    print("✅ EvoTransformer retrained and saved to trained_model/")
-
-if __name__ == "__main__":
-    retrain_and_save()
-
-# Alias to match expected import
-retrain_model = retrain_and_save
+def get_architecture_summary(model):
+    summary = {
+        "Layers": getattr(model, "num_layers", "N/A"),
+        "Attention Heads": getattr(model, "num_heads", "N/A"),
+        "FFN Dim": getattr(model, "ffn_dim", "N/A"),
+        "Memory Enabled": getattr(model, "use_memory", "N/A"),
+    }
+    return "\n".join(f"{k}: {v}" for k, v in summary.items())
+
+def retrain_model():
+    try:
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        input_ids, attention_masks, labels = fetch_training_data(tokenizer)
+
+        if input_ids is None or len(input_ids) < 2:
+            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."
+
+        config = EvoTransformerConfig()
+        model = EvoTransformerForClassification(config)
+        model.train()
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
+        loss_fn = torch.nn.CrossEntropyLoss()
+
+        for epoch in range(3):
+            optimizer.zero_grad()
+            outputs = model(input_ids, attention_mask=attention_masks)
+            loss = loss_fn(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")
+
+        os.makedirs("trained_model", exist_ok=True)
+        model.save_pretrained("trained_model")
+
+        print("✅ EvoTransformer retrained and saved.")
+
+        # Reload the updated model for summary + plot
+        updated_model = load_model()
+        arch_text = get_architecture_summary(updated_model)
+        plot = evolution_accuracy_plot()
+
+        return arch_text, plot, "✅ EvoTransformer retrained successfully!"
+
+    except Exception as e:
+        print(f"❌ Retraining failed: {e}")
+        return "❌ Error", None, f"Retrain failed: {e}"
+
+# Support direct script run
+if __name__ == "__main__":
+    retrain_model()
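
For context on what retrain_model() consumes: fetch_training_data(tokenizer) appears in this diff only through its hunk header and the tail of its return tuple. The sketch below shows one plausible shape for such a helper, assuming feedback examples are logged to a Firestore collection; the collection name "feedback" and the field names "text"/"label" are assumptions made for illustration, not code from this repository.

# Hypothetical sketch of a fetch_training_data-style helper (the real one is not
# shown in this diff). Assumes feedback documents carry "text" and "label" fields
# in a Firestore collection named "feedback" -- both names are assumptions.
import torch
from firebase_admin import firestore

def fetch_training_data_sketch(tokenizer):
    db = firestore.client()
    texts, labels = [], []
    for doc in db.collection("feedback").stream():  # collection name is an assumption
        record = doc.to_dict()
        if "text" in record and "label" in record:
            texts.append(record["text"])
            labels.append(int(record["label"]))

    if not texts:
        # Mirrors the `input_ids is None` guard in retrain_model()
        return None, None, None

    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return (
        encoded["input_ids"],
        encoded["attention_mask"],
        torch.tensor(labels, dtype=torch.long),
    )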
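
The key behavioral change in this commit is that retrain_model() now returns a three-part tuple (architecture summary text, accuracy plot, status message) instead of only printing, which suggests it is meant to back a UI callback. A minimal wiring sketch with Gradio follows; the Blocks layout, component names, and button label are assumptions for illustration and are not part of this commit.

# Hypothetical wiring sketch (not part of this commit): connects retrain_model()
# to a Gradio dashboard, assuming the three return values map to a text summary,
# a plot, and a status message, and that watchdog.py is importable from the app.
import gradio as gr
from watchdog import retrain_model

with gr.Blocks() as demo:
    arch_box = gr.Textbox(label="Architecture Summary", lines=5)
    acc_plot = gr.Plot(label="Evolution Accuracy")
    status_box = gr.Textbox(label="Status")
    retrain_btn = gr.Button("Retrain EvoTransformer")

    # Each element of the returned tuple fills one output component, in order.
    retrain_btn.click(fn=retrain_model, inputs=[], outputs=[arch_box, acc_plot, status_box])

if __name__ == "__main__":
    demo.launch()

Because retrain_model() catches exceptions internally and always returns a tuple, a callback like this never raises; failures surface through the status text instead.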