# Hugging Face Space — status: Paused
import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
DEFAULT_MODEL_ID = "google/gemma-3-1b-it"


def _get_auth_token():
    """Return the Hugging Face access token read from the ``HF_TOKEN`` env var.

    Raises:
        ValueError: if ``HF_TOKEN`` is unset or set to an empty string.
    """
    auth_token = os.environ.get("HF_TOKEN")
    # An empty string is treated like "unset": passing it on would only
    # surface later as an authentication failure during the download.
    if not auth_token:
        raise ValueError("Please set your Hugging Face token in the environment variable HF_TOKEN")
    return auth_token


def _ensure_pad_token(tokenizer):
    """Make sure *tokenizer* has a pad token, falling back to its EOS token."""
    # Keep an existing pad token; EOS is only substituted when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


class GemmaWrapper(nn.Module):
    """``nn.Module`` wrapper around a causal LM that returns only its logits.

    Args:
        model_id: Hugging Face Hub identifier of the model to load.
        token: Hugging Face access token used for the download.
        device: Device the loaded model is moved to (default ``"cpu"``).
    """

    def __init__(self, model_id, token, device="cpu"):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            # ``token`` is the current name of the deprecated ``use_auth_token``.
            token=token,
        ).to(device).eval()

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model with ``use_cache=False`` and return ``.logits``."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        )
        return outputs.logits


def main(model_id=DEFAULT_MODEL_ID):
    """Load *model_id*, run one forward pass on a sample prompt, print logits.

    Reads the access token from ``HF_TOKEN`` and runs on CUDA when available,
    otherwise on the CPU.

    Args:
        model_id: Hugging Face Hub identifier (defaults to Gemma 3 1B IT).

    Raises:
        ValueError: if ``HF_TOKEN`` is unset or empty.
    """
    # Validate the token first so a missing token fails before any download.
    auth_token = _get_auth_token()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model and tokenizer.
    model = GemmaWrapper(model_id, auth_token, device=device)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    _ensure_pad_token(tokenizer)

    # Example input.
    sentences = ["Hello"]
    tokens = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens["attention_mask"].to(device)

    # Forward pass without autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    print("Logits shape:", logits.shape)
    # Assumes logits are (batch, seq_len, vocab) as for HF causal LMs — show a small slice.
    print("Sample logits:", logits[0, :5, :5])
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()