Update app.py
app.py CHANGED
@@ -1,114 +1,50 @@
+import os
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
 import torch.nn as nn
-from torch.export import Dim
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from huggingface_hub import login
-
-def calculate_self_attention(Q, K, V, masked_fill=None, scale=None, epsilon=1e-9):
-    """
-    Args:
-        Q, K, V: [B, H, T_q, D] or [B, H, T_k, D]
-        masked_fill: Optional additive mask of shape [B, 1, T_q, T_k] or [B, H, T_q, T_k]
-        scale: Optional scaling factor (default: sqrt(D))
-        epsilon: Small constant for numerical stability
-    """
-    print("Q dtype torch:", Q.dtype)
-    print("K dtype torch:", K.dtype)
-    print("V dtype torch:", V.dtype)
-
-    B, H, T_q, D = Q.shape
-    T_k = K.shape[2] # number of key tokens
-    scale = scale or np.sqrt(D)
-
-    log(f"Q: {np.sum(Q):.4f}, K: {np.sum(K):.4f}, V: {np.sum(V):.4f}")
-
-    # Step 1: Raw attention logits
-    QK_output = np.matmul(Q, K.transpose(0, 1, 3, 2)) # [B, H, T_q, T_k]
-    logits_unmasked = QK_output / scale
-    print(f"QK_output--- shape: {QK_output.shape}, value: {np.sum(QK_output):.4f}")
-    print(f"logits_unmasked--- shape: {logits_unmasked.shape}, value: {np.sum(logits_unmasked):.4f}")
-
-    ###########################################################################
-    def _nz_stats(name, arr, tol=1e-12):
-        total = arr.size
-        zeros = np.count_nonzero(np.abs(arr) < tol)
-        nonzeros = total - zeros
-        pct = (nonzeros / total) * 100
-        print(f"{name}: nonzeros={nonzeros} ({pct:.2f}%), zeros={zeros}")
-
-    # Debug: non-zero stats
-    _nz_stats("Q: ", Q)
-    _nz_stats("K: ", K)
-    _nz_stats("V: ", V)
-    _nz_stats("QK_output: ", QK_output)
-    ###########################################################################
-
-    # Step 2: Softmax over unmasked logits (for debugging or interpretability)
-    A = np.exp(logits_unmasked - np.max(logits_unmasked, axis=-1, keepdims=True))
-    A = A / (np.sum(A, axis=-1, keepdims=True) + epsilon)
-    log(f"A (unmasked attention weights) --- shape: {A.shape}, value: {np.sum(A):.4f}")
-
-    # Step 3: Apply additive attention mask (optional)
-    masked_fill = None
-    if masked_fill is not None:
-        logits_masked = logits_unmasked + masked_fill # [B, H, T, T] + [B, 1, T, T]
-        log(f"masked_fill--- minimum: {np.min(masked_fill)}, maximum: {np.max(masked_fill)}")
-    else:
-        logits_masked = logits_unmasked.copy()
-    log(f"logits_masked --- shape: {logits_masked.shape}, value: {np.sum(logits_masked):.4f}")
-
-    # Step 4: Softmax over masked logits
-    A_masked = np.exp(logits_masked - np.max(logits_masked, axis=-1, keepdims=True))
-    A_masked = A_masked / (np.sum(A_masked, axis=-1, keepdims=True) + epsilon)
-    log(f"A_masked (masked attention weights)--- shape: {A_masked.shape}, value: {np.sum(A_masked):.4f}")
-
-    # Step 5: Compute attention output using masked weights
-    attention_output = np.matmul(A_masked, V) # [B, H, T_q, D]
-    log(f"attention_output (using A_masked)--- shape: {attention_output.shape}, value: {np.sum(attention_output):.4f}")
 
 def main():
-
-
-    auth_token
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Get Hugging Face token from environment variable
+    auth_token = os.environ.get("HF_TOKEN")
+    if auth_token is None:
+        raise ValueError("Please set your Hugging Face token in the environment variable HF_TOKEN")
+
+    # Model ID
+    model_id = "google/gemma-3-1b-it"
+
+    # Device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Wrapper class
+    class GemmaWrapper(nn.Module):
+        def __init__(self, model_id, token):
+            super().__init__()
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.float32,
+                use_auth_token=token
+            ).to(device).eval()
+
+        def forward(self, input_ids, attention_mask):
+            return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
+
+    # Load model and tokenizer
+    model = GemmaWrapper(model_id, auth_token)
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # Example input
+    sentences = ["Hello"]
+    tokens = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
+    input_ids = tokens["input_ids"].to(device)
+    attention_mask = tokens["attention_mask"].to(device)
+
+    # Forward pass
+    with torch.no_grad():
+        logits = model(input_ids=input_ids, attention_mask=attention_mask)
+
+    print("Logits shape:", logits.shape)
+    print("Sample logits:", logits[0, :5, :5]) # show small slice
 
 if __name__ == "__main__":
     main()
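Note on the removed code: the deleted calculate_self_attention helper implemented standard scaled dot-product attention, but as written it operated on NumPy arrays through np and called a log() helper, neither of which was imported or defined in app.py, and it reset masked_fill to None before checking it, so the mask branch never ran. For reference only, here is a minimal self-contained PyTorch sketch of the same computation; the function name and the additive_mask argument are illustrative and not part of this Space.

import math
import torch

def scaled_dot_product_attention(Q, K, V, additive_mask=None):
    # Q, K, V: [B, H, T, D]; additive_mask broadcastable to [B, H, T_q, T_k]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])  # [B, H, T_q, T_k]
    if additive_mask is not None:
        scores = scores + additive_mask
    weights = torch.softmax(scores, dim=-1)  # attention weights sum to 1 over keys
    return weights @ V  # [B, H, T_q, D]

# Tiny smoke test with random tensors
B, H, T, D = 1, 2, 4, 8
Q, K, V = (torch.randn(B, H, T, D) for _ in range(3))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([1, 2, 4, 8])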
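Note on the new code: main() passes use_auth_token= to from_pretrained, which recent transformers releases deprecate in favor of token=. If the Space's environment warns about this, a drop-in substitution along the following lines should work; this is a sketch under that assumption, not part of the commit.

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

auth_token = os.environ.get("HF_TOKEN")
model_id = "google/gemma-3-1b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Same load as in main(), but with the newer `token=` argument, which replaces
# the deprecated `use_auth_token=` in recent transformers releases.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    token=auth_token,
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
tokenizer.pad_token = tokenizer.eos_token

Either way, with HF_TOKEN set (for example as a Space secret), running app.py prints a logits tensor of shape [batch_size, sequence_length, vocab_size] for the example prompt.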