# Gemma3 / app.py
# Header captured from the hosting page along with the file:
# rahul7star - "Update app.py" - commit baff828 (verified) - 1.69 kB
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
def main():
    """Run one forward pass of ``google/gemma-3-1b-it`` and print its logits.

    Reads the Hugging Face access token from the ``HF_TOKEN`` environment
    variable, loads the model and tokenizer, tokenizes the single sentence
    ``"Hello"``, runs the model without gradient tracking and prints the
    logits shape together with a small slice of the values.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
    """
    # Get Hugging Face token from environment variable; fail before any
    # download is attempted when it is missing.
    auth_token = os.environ.get("HF_TOKEN")
    if auth_token is None:
        raise ValueError("Please set your Hugging Face token in the environment variable HF_TOKEN")

    # Model ID
    model_id = "google/gemma-3-1b-it"

    # Device: use the GPU when torch reports one, otherwise the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    class GemmaWrapper(nn.Module):
        """``nn.Module`` wrapper whose ``forward`` returns only the logits."""

        def __init__(self, model_id, token):
            """Load the causal LM in float32, move it to ``device``, set eval mode."""
            super().__init__()
            # `token=` is the current keyword for the access token; the
            # older `use_auth_token=` spelling is deprecated.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                token=token,
            ).to(device).eval()

        def forward(self, input_ids, attention_mask):
            """Return the ``.logits`` of the wrapped model (``use_cache=False``)."""
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
            return outputs.logits

    # Load model and tokenizer
    model = GemmaWrapper(model_id, auth_token)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    # Only fall back to EOS as the padding token when the tokenizer does not
    # define one itself, so an existing dedicated pad token is not clobbered.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Example input
    sentences = ["Hello"]
    tokens = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens["attention_mask"].to(device)

    # Forward pass (inference only, so no autograd bookkeeping)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    print("Logits shape:", logits.shape)
    print("Sample logits:", logits[0, :5, :5])  # show small slice
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()