Update app.py
app.py CHANGED
@@ -1,78 +1,35 @@

Removed (the previous app.py; only `import torch` carries over into the new version):

```python
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

#
# ... (old lines 7-38, including the GemmaWrapper class header, are not visible in this view)
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            use_auth_token=token
        ).to(device).eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits

# Load model & tokenizer
model = GemmaWrapper(model_id, auth_token)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
tokenizer.pad_token = tokenizer.eos_token

# Example input
sentences = ["Hello"]
tokens = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
input_ids = tokens["input_ids"].to(device)
attention_mask = tokens["attention_mask"].to(device)

# Dynamic shapes (optional, can be used by ToyModel)
batch_dim = 1
seq_dim = input_ids.shape[1]
dynamic_shapes = {
    "input_ids": {0: batch_dim, 1: seq_dim},
    "attention_mask": {0: batch_dim, 1: seq_dim},
}

# --- ToyModel usage ---
ir = ToyModel(model, (input_ids, attention_mask), dynamic_shapes=dynamic_shapes)
io_data = ir.predict(input_ids, attention_mask)
ir.evaluation()

# Output
print("Predicted logits shape:", io_data.shape)
print("Sample logits:", io_data[0, :5, :5])

if __name__ == "__main__":
    main()
```
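The definitions surrounding the `super().__init__()` fragment are elided in the diff view. Judging only from the visible call sites (`GemmaWrapper(model_id, auth_token)`, `.to(device)`), the missing scaffold probably resembled the sketch below. Only the names are confirmed by the source; the `device`, `model_id`, and `auth_token` assignments are assumptions, and the model id is a placeholder.

```python
import os
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

# Hypothetical reconstruction of the elided lines; only the identifiers,
# not these definitions, appear in the visible fragments.
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption
model_id = "google/gemma-2b"                             # placeholder model id (assumption)
auth_token = os.environ.get("HF_TOKEN")                  # assumption: token read from the environment

class GemmaWrapper(nn.Module):
    def __init__(self, model_id, token):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            use_auth_token=token,
        ).to(device).eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
```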
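`ToyModel` itself is not defined anywhere in the visible file, but its signature (a model, a tuple of example inputs, and a `dynamic_shapes` mapping) mirrors `torch.export.export`. A minimal sketch of the equivalent call with the public export API, assuming PyTorch 2.x (the `Dim` names are illustrative):

```python
import torch
from torch.export import Dim, export

# Symbolic batch/sequence dimensions, keyed by forward() argument name.
batch = Dim("batch")
seq = Dim("seq")
dynamic_shapes = {
    "input_ids": {0: batch, 1: seq},
    "attention_mask": {0: batch, 1: seq},
}

# Trace the wrapper into an ExportedProgram with dynamic dims.
exported = export(model, (input_ids, attention_mask), dynamic_shapes=dynamic_shapes)

# The exported program can be re-run like the original module.
logits = exported.module()(input_ids, attention_mask)
print("Predicted logits shape:", logits.shape)
```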
Added (the new app.py):

```python
import torch

# Utility to log tensor info
def log_tensor(name, x):
    print(f"--- {name} ---")
    print(f"shape: {x.shape}, dtype: {x.dtype}, device: {x.device}")
    print(f"min: {x.min().item():.6f}, max: {x.max().item():.6f}, mean: {x.mean().item():.6f}, sum: {x.sum().item():.6f}")
    print(f"full tensor:\n{x}\n")

# Simple function
def g(x, y):
    log_tensor("g input x", x)
    log_tensor("g input y", y)
    z = x + y
    log_tensor("g output z", z)
    return z

# Compiled function
@torch.compile(backend="eager")
def f(x):
    log_tensor("f input x", x)
    x = torch.sin(x)
    log_tensor("f after torch.sin(x)", x)
    x = g(x, x)
    log_tensor("f after g(x, x)", x)
    return x

# Example input
x = torch.ones(3, 3, dtype=torch.float32)
log_tensor("original input x", x)

# Run compiled function
out = f(x)

log_tensor("final output", out)
```
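The new script is a minimal repro of how `torch.compile` interacts with Python side effects: `backend="eager"` makes Dynamo capture FX graphs but run them with the normal eager interpreter, and every `print` inside `log_tensor` forces a graph break, so the tensor ops between logging calls are compiled as separate graphs. One way to confirm the break count, sketched here assuming PyTorch >= 2.1 where `torch._dynamo.explain` wraps the function before taking the inputs:

```python
import torch
import torch._dynamo as dynamo

def g(x, y):
    print("in g")  # Python side effect -> graph break
    return x + y

def f(x):
    print("in f")  # Python side effect -> graph break
    x = torch.sin(x)
    return g(x, x)

# explain() reports how many graphs Dynamo captured and the reason
# for each graph break.
explanation = dynamo.explain(f)(torch.ones(3, 3))
print(explanation)
```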