Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -3,8 +3,26 @@ import torch
|
|
3 |
import torch.nn as nn
|
4 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
def main():
|
7 |
-
#
|
8 |
auth_token = os.environ.get("HF_TOKEN")
|
9 |
if not auth_token:
|
10 |
raise ValueError("Please set HF_TOKEN environment variable")
|
@@ -39,12 +57,22 @@ def main():
|
|
39 |
input_ids = tokens["input_ids"].to(device)
|
40 |
attention_mask = tokens["attention_mask"].to(device)
|
41 |
|
42 |
-
#
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
|
47 |
-
print("
|
|
|
48 |
|
49 |
if __name__ == "__main__":
|
50 |
main()
|
|
|
3 |
import torch.nn as nn
|
4 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
|
6 |
+
# --- ToyModel placeholder ---
|
7 |
+
# Replace this with your actual ToyModel import if it exists
|
8 |
+
class ToyModel:
|
9 |
+
def __init__(self, model, inputs, dynamic_shapes=None):
|
10 |
+
self.model = model
|
11 |
+
self.inputs = inputs
|
12 |
+
self.dynamic_shapes = dynamic_shapes
|
13 |
+
|
14 |
+
def predict(self, input_ids, attention_mask):
|
15 |
+
with torch.no_grad():
|
16 |
+
logits = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
17 |
+
return logits
|
18 |
+
|
19 |
+
def evaluation(self):
|
20 |
+
print("Evaluation complete.")
|
21 |
+
|
22 |
+
# ----------------------------
|
23 |
+
|
24 |
def main():
|
25 |
+
# Hugging Face token from environment
|
26 |
auth_token = os.environ.get("HF_TOKEN")
|
27 |
if not auth_token:
|
28 |
raise ValueError("Please set HF_TOKEN environment variable")
|
|
|
57 |
input_ids = tokens["input_ids"].to(device)
|
58 |
attention_mask = tokens["attention_mask"].to(device)
|
59 |
|
60 |
+
# Dynamic shapes (optional, can be used by ToyModel)
|
61 |
+
batch_dim = 1
|
62 |
+
seq_dim = input_ids.shape[1]
|
63 |
+
dynamic_shapes = {
|
64 |
+
"input_ids": {0: batch_dim, 1: seq_dim},
|
65 |
+
"attention_mask": {0: batch_dim, 1: seq_dim},
|
66 |
+
}
|
67 |
+
|
68 |
+
# --- ToyModel usage ---
|
69 |
+
ir = ToyModel(model, (input_ids, attention_mask), dynamic_shapes=dynamic_shapes)
|
70 |
+
io_data = ir.predict(input_ids, attention_mask)
|
71 |
+
ir.evaluation()
|
72 |
|
73 |
+
# Output
|
74 |
+
print("Predicted logits shape:", io_data.shape)
|
75 |
+
print("Sample logits:", io_data[0, :5, :5])
|
76 |
|
77 |
if __name__ == "__main__":
|
78 |
main()
|