rahul7star committed
Commit 4593239 · verified · 1 Parent(s): 8145ee9

Update app.py

Files changed (1):
  app.py  +34 −6
app.py CHANGED
@@ -3,8 +3,26 @@ import torch
 import torch.nn as nn
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
+# --- ToyModel placeholder ---
+# Replace this with your actual ToyModel import if it exists
+class ToyModel:
+    def __init__(self, model, inputs, dynamic_shapes=None):
+        self.model = model
+        self.inputs = inputs
+        self.dynamic_shapes = dynamic_shapes
+
+    def predict(self, input_ids, attention_mask):
+        with torch.no_grad():
+            logits = self.model(input_ids=input_ids, attention_mask=attention_mask)
+        return logits
+
+    def evaluation(self):
+        print("Evaluation complete.")
+
+# ----------------------------
+
 def main():
-    # Get Hugging Face token from env
+    # Hugging Face token from environment
     auth_token = os.environ.get("HF_TOKEN")
     if not auth_token:
         raise ValueError("Please set HF_TOKEN environment variable")
@@ -39,12 +57,22 @@ def main():
     input_ids = tokens["input_ids"].to(device)
     attention_mask = tokens["attention_mask"].to(device)
 
-    # Forward pass
-    with torch.no_grad():
-        logits = model(input_ids, attention_mask)
+    # Dynamic shapes (optional, can be used by ToyModel)
+    batch_dim = 1
+    seq_dim = input_ids.shape[1]
+    dynamic_shapes = {
+        "input_ids": {0: batch_dim, 1: seq_dim},
+        "attention_mask": {0: batch_dim, 1: seq_dim},
+    }
+
+    # --- ToyModel usage ---
+    ir = ToyModel(model, (input_ids, attention_mask), dynamic_shapes=dynamic_shapes)
+    io_data = ir.predict(input_ids, attention_mask)
+    ir.evaluation()
 
-    print("Logits shape:", logits.shape)
-    print("Sample logits:", logits[0, :5, :5])
+    # Output
+    print("Predicted logits shape:", io_data.shape)
+    print("Sample logits:", io_data[0, :5, :5])
 
 if __name__ == "__main__":
     main()
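
For context, the following is a minimal, self-contained sketch of how the updated main() could run end to end. It is not the committed file: the model-loading code between the two hunks is not part of this diff, so the sketch uses the public sshleifer/tiny-gpt2 checkpoint as a hypothetical stand-in (no HF_TOKEN needed), and it reads the .logits attribute from the AutoModelForCausalLM output, since the committed predict() returns the raw ModelOutput object, on which io_data.shape would not be a tensor shape.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class ToyModel:
    # Same placeholder wrapper as in the diff, with the logits tensor unpacked explicitly.
    def __init__(self, model, inputs, dynamic_shapes=None):
        self.model = model
        self.inputs = inputs
        self.dynamic_shapes = dynamic_shapes

    def predict(self, input_ids, attention_mask):
        with torch.no_grad():
            # AutoModelForCausalLM returns a ModelOutput; take its logits tensor.
            out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return out.logits

    def evaluation(self):
        print("Evaluation complete.")


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    name = "sshleifer/tiny-gpt2"  # hypothetical stand-in for the model loaded in the real app.py
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name).to(device).eval()

    tokens = tokenizer("Hello world", return_tensors="pt")
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens["attention_mask"].to(device)

    # Dynamic-shape bookkeeping, mirroring the commit.
    dynamic_shapes = {
        "input_ids": {0: 1, 1: input_ids.shape[1]},
        "attention_mask": {0: 1, 1: input_ids.shape[1]},
    }

    ir = ToyModel(model, (input_ids, attention_mask), dynamic_shapes=dynamic_shapes)
    logits = ir.predict(input_ids, attention_mask)
    ir.evaluation()
    print("Predicted logits shape:", logits.shape)


if __name__ == "__main__":
    main()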