Spaces:
Running
Running
Update watchdog.py
Browse files- watchdog.py +47 -25
watchdog.py
CHANGED
@@ -5,6 +5,9 @@ from firebase_admin import credentials, firestore
|
|
5 |
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
|
6 |
from transformers import BertTokenizer
|
7 |
|
|
|
|
|
|
|
8 |
# Initialize Firebase if not already initialized
|
9 |
if not firebase_admin._apps:
|
10 |
cred = credentials.Certificate("firebase_key.json")
|
@@ -44,35 +47,54 @@ def fetch_training_data(tokenizer):
|
|
44 |
torch.tensor(labels, dtype=torch.long)
|
45 |
)
|
46 |
|
47 |
-
def
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
model.train()
|
58 |
|
59 |
-
|
60 |
-
|
|
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
outputs = model(input_ids, attention_mask=attention_masks)
|
65 |
-
loss = loss_fn(outputs, labels)
|
66 |
-
loss.backward()
|
67 |
-
optimizer.step()
|
68 |
-
print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
|
6 |
from transformers import BertTokenizer
|
7 |
|
8 |
+
from init_model import load_model
|
9 |
+
from dashboard import evolution_accuracy_plot
|
10 |
+
|
11 |
# Initialize Firebase if not already initialized
|
12 |
if not firebase_admin._apps:
|
13 |
cred = credentials.Certificate("firebase_key.json")
|
|
|
47 |
torch.tensor(labels, dtype=torch.long)
|
48 |
)
|
49 |
|
50 |
+
def get_architecture_summary(model):
    """Render key EvoTransformer hyperparameters as a short text report.

    Reads ``num_layers`` / ``num_heads`` / ``ffn_dim`` / ``use_memory``
    directly off the model object; any attribute the model lacks is
    reported as "N/A". Returns one "Label: value" line per field,
    newline-separated, in a fixed order.
    """
    # (display label, attribute name) pairs, in display order.
    fields = [
        ("Layers", "num_layers"),
        ("Attention Heads", "num_heads"),
        ("FFN Dim", "ffn_dim"),
        ("Memory Enabled", "use_memory"),
    ]
    lines = []
    for label, attr in fields:
        lines.append(f"{label}: {getattr(model, attr, 'N/A')}")
    return "\n".join(lines)
|
58 |
|
59 |
+
def retrain_model():
    """Retrain EvoTransformer from logged feedback and refresh UI artifacts.

    Workflow: fetch tokenized feedback examples via fetch_training_data,
    train a fresh EvoTransformerForClassification for three full-batch
    epochs, save it under ./trained_model, then reload it to build an
    architecture summary and an accuracy plot for the dashboard.

    Returns:
        tuple: (architecture_text, plot_or_None, status_message). On any
        exception, returns an error marker, None, and the failure reason.
    """
    try:
        tok = BertTokenizer.from_pretrained("bert-base-uncased")
        input_ids, attention_masks, labels = fetch_training_data(tok)

        # Refuse to train on fewer than two logged examples.
        if input_ids is None or len(input_ids) < 2:
            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."

        # A fresh model is built from the default config on every retrain.
        net = EvoTransformerForClassification(EvoTransformerConfig())
        net.train()

        opt = torch.optim.Adam(net.parameters(), lr=2e-4)
        criterion = torch.nn.CrossEntropyLoss()

        # Feedback sets are small: three epochs, one full batch per epoch.
        for epoch_idx in range(3):
            opt.zero_grad()
            logits = net(input_ids, attention_mask=attention_masks)
            epoch_loss = criterion(logits, labels)
            epoch_loss.backward()
            opt.step()
            print(f"Epoch {epoch_idx+1}: Loss = {epoch_loss.item():.4f}")

        os.makedirs("trained_model", exist_ok=True)
        net.save_pretrained("trained_model")

        print("✅ EvoTransformer retrained and saved.")

        # Reload the updated model for summary + plot
        refreshed = load_model()
        arch_text = get_architecture_summary(refreshed)
        plot = evolution_accuracy_plot()

        return arch_text, plot, "✅ EvoTransformer retrained successfully!"

    except Exception as e:
        # Boundary handler: log and report rather than crash the UI.
        print(f"❌ Retraining failed: {e}")
        return "❌ Error", None, f"Retrain failed: {e}"
|
97 |
+
|
98 |
+
# Support direct script run: trigger one retraining pass and exit.
if __name__ == "__main__":
    retrain_model()
|