Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -4,18 +4,18 @@ import torch.nn as nn
|
|
4 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
|
6 |
def main():
|
7 |
-
# Get Hugging Face token from
|
8 |
auth_token = os.environ.get("HF_TOKEN")
|
9 |
-
if auth_token
|
10 |
-
raise ValueError("Please set
|
11 |
-
|
12 |
-
# Model ID
|
13 |
-
model_id = "google/gemma-3-1b-it"
|
14 |
|
15 |
# Device
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
-
#
|
|
|
|
|
|
|
19 |
class GemmaWrapper(nn.Module):
|
20 |
def __init__(self, model_id, token):
|
21 |
super().__init__()
|
@@ -28,7 +28,7 @@ def main():
|
|
28 |
def forward(self, input_ids, attention_mask):
|
29 |
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
|
30 |
|
31 |
-
# Load model
|
32 |
model = GemmaWrapper(model_id, auth_token)
|
33 |
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
|
34 |
tokenizer.pad_token = tokenizer.eos_token
|
@@ -41,10 +41,10 @@ def main():
|
|
41 |
|
42 |
# Forward pass
|
43 |
with torch.no_grad():
|
44 |
-
logits = model(input_ids
|
45 |
|
46 |
print("Logits shape:", logits.shape)
|
47 |
-
print("Sample logits:", logits[0, :5, :5])
|
48 |
|
49 |
if __name__ == "__main__":
|
50 |
main()
|
|
|
4 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
|
6 |
def main():
|
7 |
+
# Get Hugging Face token from env
|
8 |
auth_token = os.environ.get("HF_TOKEN")
|
9 |
+
if not auth_token:
|
10 |
+
raise ValueError("Please set HF_TOKEN environment variable")
|
|
|
|
|
|
|
11 |
|
12 |
# Device
|
13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
|
15 |
+
# Model ID
|
16 |
+
model_id = "google/gemma-3-1b-it"
|
17 |
+
|
18 |
+
# Wrapper
|
19 |
class GemmaWrapper(nn.Module):
|
20 |
def __init__(self, model_id, token):
|
21 |
super().__init__()
|
|
|
28 |
def forward(self, input_ids, attention_mask):
|
29 |
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
|
30 |
|
31 |
+
# Load model & tokenizer
|
32 |
model = GemmaWrapper(model_id, auth_token)
|
33 |
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
|
34 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
41 |
|
42 |
# Forward pass
|
43 |
with torch.no_grad():
|
44 |
+
logits = model(input_ids, attention_mask)
|
45 |
|
46 |
print("Logits shape:", logits.shape)
|
47 |
+
print("Sample logits:", logits[0, :5, :5])
|
48 |
|
49 |
if __name__ == "__main__":
|
50 |
main()
|