rahul7star commited on
Commit
8145ee9
·
verified ·
1 Parent(s): baff828

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -4,18 +4,18 @@ import torch.nn as nn
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
  def main():
7
- # Get Hugging Face token from environment variable
8
  auth_token = os.environ.get("HF_TOKEN")
9
- if auth_token is None:
10
- raise ValueError("Please set your Hugging Face token in the environment variable HF_TOKEN")
11
-
12
- # Model ID
13
- model_id = "google/gemma-3-1b-it"
14
 
15
  # Device
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
- # Wrapper class
 
 
 
19
  class GemmaWrapper(nn.Module):
20
  def __init__(self, model_id, token):
21
  super().__init__()
@@ -28,7 +28,7 @@ def main():
28
  def forward(self, input_ids, attention_mask):
29
  return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
30
 
31
- # Load model and tokenizer
32
  model = GemmaWrapper(model_id, auth_token)
33
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
34
  tokenizer.pad_token = tokenizer.eos_token
@@ -41,10 +41,10 @@ def main():
41
 
42
  # Forward pass
43
  with torch.no_grad():
44
- logits = model(input_ids=input_ids, attention_mask=attention_mask)
45
 
46
  print("Logits shape:", logits.shape)
47
- print("Sample logits:", logits[0, :5, :5]) # show small slice
48
 
49
  if __name__ == "__main__":
50
  main()
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
  def main():
7
+ # Get Hugging Face token from env
8
  auth_token = os.environ.get("HF_TOKEN")
9
+ if not auth_token:
10
+ raise ValueError("Please set HF_TOKEN environment variable")
 
 
 
11
 
12
  # Device
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
+ # Model ID
16
+ model_id = "google/gemma-3-1b-it"
17
+
18
+ # Wrapper
19
  class GemmaWrapper(nn.Module):
20
  def __init__(self, model_id, token):
21
  super().__init__()
 
28
  def forward(self, input_ids, attention_mask):
29
  return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
30
 
31
+ # Load model & tokenizer
32
  model = GemmaWrapper(model_id, auth_token)
33
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
34
  tokenizer.pad_token = tokenizer.eos_token
 
41
 
42
  # Forward pass
43
  with torch.no_grad():
44
+ logits = model(input_ids, attention_mask)
45
 
46
  print("Logits shape:", logits.shape)
47
+ print("Sample logits:", logits[0, :5, :5])
48
 
49
  if __name__ == "__main__":
50
  main()