code-explainer-c / model /model_utils.py
SandeepU's picture
Upload model_utils.py
841ffa5 verified
raw
history blame
828 Bytes
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
def load_model(model_name="mrm8488/codebert-base-finetuned-stackoverflow"):
    """Load a sequence-classification checkpoint together with its tokenizer.

    Args:
        model_name: Identifier handed to both ``from_pretrained`` calls.
            Defaults to the checkpoint this function was previously
            hard-wired to, so a bare ``load_model()`` behaves as before.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``model`` has been put in
        evaluation mode and moved to ``device``; ``device`` is the string
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # This module only runs inference, so leave training mode right away.
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return tokenizer, model, device
def generate_explanation(code, tokenizer, model, device):
    """Classify ``code`` and report the predicted class id as a sentence.

    Args:
        code: Input passed straight to ``tokenizer``. If the tokenizer turns
            it into more than one row, only the first row is reported.
        tokenizer: Callable invoked as
            ``tokenizer(code, return_tensors="pt", truncation=True,
            padding=True)``; its result must support ``.to(device)`` and
            ``**`` unpacking.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object with a ``logits`` tensor.
        device: Device the encoded inputs are moved to; should match the
            device the model lives on.

    Returns:
        A human-readable string containing the integer id of the
        highest-scoring class.
    """
    inputs = tokenizer(code, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Assumes logits is laid out as (batch, num_labels) -- TODO confirm for
    # custom models. Reduce over the class axis of the first row: a bare
    # argmax() flattens the tensor first, so with several rows it yields a
    # flat offset that is not a class id.
    predicted_class_id = logits[0].argmax(dim=-1).item()
    return f"This code is classified as category ID: {predicted_class_id} (label may vary based on fine-tuning objective)"