rahul7star committed on
Commit
baff828
·
verified ·
1 Parent(s): 8e2085e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -106
app.py CHANGED
@@ -1,114 +1,50 @@
 
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
-
4
  import torch.nn as nn
5
- from torch.export import Dim
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
- from huggingface_hub import login
8
-
9
- def calculate_self_attention(Q, K, V, masked_fill=None, scale=None, epsilon=1e-9):
10
- """
11
- Args:
12
- Q, K, V: [B, H, T_q, D] or [B, H, T_k, D]
13
- masked_fill: Optional additive mask of shape [B, 1, T_q, T_k] or [B, H, T_q, T_k]
14
- scale: Optional scaling factor (default: sqrt(D))
15
- epsilon: Small constant for numerical stability
16
- """
17
- print("Q dtype torch:", Q.dtype)
18
- print("K dtype torch:", K.dtype)
19
- print("V dtype torch:", V.dtype)
20
-
21
- B, H, T_q, D = Q.shape
22
- T_k = K.shape[2] # number of key tokens
23
- scale = scale or np.sqrt(D)
24
-
25
- log(f"Q: {np.sum(Q):.4f}, K: {np.sum(K):.4f}, V: {np.sum(V):.4f}")
26
-
27
- # Step 1: Raw attention logits
28
- QK_output = np.matmul(Q, K.transpose(0, 1, 3, 2)) # [B, H, T_q, T_k]
29
- logits_unmasked = QK_output / scale
30
- print(f"QK_output--- shape: {QK_output.shape}, value: {np.sum(QK_output):.4f}")
31
- print(f"logits_unmasked--- shape: {logits_unmasked.shape}, value: {np.sum(logits_unmasked):.4f}")
32
-
33
- ###########################################################################
34
- def _nz_stats(name, arr, tol=1e-12):
35
- total = arr.size
36
- zeros = np.count_nonzero(np.abs(arr) < tol)
37
- nonzeros = total - zeros
38
- pct = (nonzeros / total) * 100
39
- print(f"{name}: nonzeros={nonzeros} ({pct:.2f}%), zeros={zeros}")
40
-
41
- # Debug: non-zero stats
42
- _nz_stats("Q: ", Q)
43
- _nz_stats("K: ", K)
44
- _nz_stats("V: ", V)
45
- _nz_stats("QK_output: ", QK_output)
46
- ###########################################################################
47
-
48
- # Step 2: Softmax over unmasked logits (for debugging or interpretability)
49
- A = np.exp(logits_unmasked - np.max(logits_unmasked, axis=-1, keepdims=True))
50
- A = A / (np.sum(A, axis=-1, keepdims=True) + epsilon)
51
- log(f"A (unmasked attention weights) --- shape: {A.shape}, value: {np.sum(A):.4f}")
52
-
53
- # Step 3: Apply additive attention mask (optional)
54
- masked_fill = None
55
- if masked_fill is not None:
56
- logits_masked = logits_unmasked + masked_fill # [B, H, T, T] + [B, 1, T, T]
57
- log(f"masked_fill--- minimum: {np.min(masked_fill)}, maximum: {np.max(masked_fill)}")
58
- else:
59
- logits_masked = logits_unmasked.copy()
60
- log(f"logits_masked --- shape: {logits_masked.shape}, value: {np.sum(logits_masked):.4f}")
61
-
62
- # Step 4: Softmax over masked logits
63
- A_masked = np.exp(logits_masked - np.max(logits_masked, axis=-1, keepdims=True))
64
- A_masked = A_masked / (np.sum(A_masked, axis=-1, keepdims=True) + epsilon)
65
- log(f"A_masked (masked attention weights)--- shape: {A_masked.shape}, value: {np.sum(A_masked):.4f}")
66
-
67
- # Step 5: Compute attention output using masked weights
68
- attention_output = np.matmul(A_masked, V) # [B, H, T_q, D]
69
- log(f"attention_output (using A_masked)--- shape: {attention_output.shape}, value: {np.sum(attention_output):.4f}")
70
 
71
  def main():
72
-
73
- model_id = “google/gemma-3-1b-it”
74
- auth_token = “[Here, place your Huggingface Authentication Token]”
75
-
76
-
77
- class GemmaWrapper(nn.Module):
78
- def __init__(self, model_id, token):
79
- super().__init__()
80
- self.model = AutoModelForCausalLM.from_pretrained(
81
- model_id,
82
- torch_dtype=torch.float32,
83
- token=token
84
- ).eval()
85
-
86
- def forward(self, input_ids, attention_mask):
87
- return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
88
-
89
-
90
-
91
- model = GemmaWrapper(model_id, auth_token)
92
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
93
- tokenizer.pad_token = tokenizer.eos_token
94
-
95
- sentences = [“Hello”]
96
-
97
- tokens = tokenizer(sentences, return_tensors=“pt”, padding=True, truncation=True)
98
- input_ids = tokens[“input_ids”]
99
- attention_mask = tokens[“attention_mask”]
100
-
101
- if len(sentences) > 1:
102
- batch_dim = Dim(“batch”, min=1, max=len(sentences))
103
- else:
104
- batch_dim = 1 # Static dimension
105
-
106
- seq_dim = Dim(“seq”, min=1, max=input_ids.shape[1])
107
-
108
- dynamic_shapes = {
109
- “input_ids”: {0: batch_dim, 1: seq_dim},
110
- “attention_mask”: {0: batch_dim, 1: seq_dim},
111
- }
 
112
 
113
  if __name__ == "__main__":
114
  main()
 
1
+ import os
2
  import torch
 
 
3
  import torch.nn as nn
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def main():
7
+ # Get Hugging Face token from environment variable
8
+ auth_token = os.environ.get("HF_TOKEN")
9
+ if auth_token is None:
10
+ raise ValueError("Please set your Hugging Face token in the environment variable HF_TOKEN")
11
+
12
+ # Model ID
13
+ model_id = "google/gemma-3-1b-it"
14
+
15
+ # Device
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ # Wrapper class
19
+ class GemmaWrapper(nn.Module):
20
+ def __init__(self, model_id, token):
21
+ super().__init__()
22
+ self.model = AutoModelForCausalLM.from_pretrained(
23
+ model_id,
24
+ torch_dtype=torch.float32,
25
+ use_auth_token=token
26
+ ).to(device).eval()
27
+
28
+ def forward(self, input_ids, attention_mask):
29
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
30
+
31
+ # Load model and tokenizer
32
+ model = GemmaWrapper(model_id, auth_token)
33
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
34
+ tokenizer.pad_token = tokenizer.eos_token
35
+
36
+ # Example input
37
+ sentences = ["Hello"]
38
+ tokens = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
39
+ input_ids = tokens["input_ids"].to(device)
40
+ attention_mask = tokens["attention_mask"].to(device)
41
+
42
+ # Forward pass
43
+ with torch.no_grad():
44
+ logits = model(input_ids=input_ids, attention_mask=attention_mask)
45
+
46
+ print("Logits shape:", logits.shape)
47
+ print("Sample logits:", logits[0, :5, :5]) # show small slice
48
 
49
  if __name__ == "__main__":
50
  main()