SandeepU commited on
Commit
fd3396f
·
verified ·
1 Parent(s): 79881fe

Upload model_utils.py

Browse files
Files changed (1) hide show
  1. model/model_utils.py +11 -13
model/model_utils.py CHANGED
@@ -1,19 +1,17 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
3
 
4
  def load_model():
5
- model_name = "mrm8488/codebert-base-finetuned-stackoverflow"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
  model.eval()
 
 
9
 
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model.to(device)
12
- return tokenizer, model, device
13
-
14
- def generate_explanation(code, tokenizer, model, device):
15
- inputs = tokenizer(code, return_tensors="pt", truncation=True, padding=True).to(device)
16
- with torch.no_grad():
17
- logits = model(**inputs).logits
18
- predicted_class_id = logits.argmax().item()
19
- return f"This code is classified as category ID: {predicted_class_id} (label may vary based on fine-tuning objective)"
 
1
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
2
  import torch
3
 
4
  def load_model():
5
+ model_name = "Salesforce/codet5-small"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
8
  model.eval()
9
+ model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
10
+ return tokenizer, model
11
 
12
+ def generate_explanation(code, tokenizer, model):
13
+ device = model.device
14
+ input_text = "summarize: " + code
15
+ input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(device)
16
+ output = model.generate(input_ids, max_new_tokens=150, early_stopping=True)
17
+ return tokenizer.decode(output[0], skip_special_tokens=True)