SandeepU committed on
Commit
841ffa5
·
verified ·
1 Parent(s): 782b268

Upload model_utils.py

Browse files
Files changed (1) hide show
  1. model/model_utils.py +9 -12
model/model_utils.py CHANGED
@@ -1,22 +1,19 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
 
4
def load_model():
    """Load the CodeT5 seq2seq checkpoint and its tokenizer.

    Returns:
        tuple: (tokenizer, model, device) — the model is in eval mode and
        moved to "cuda" when a GPU is available, otherwise "cpu".
    """
    checkpoint = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    seq2seq.eval()  # inference only; disables dropout etc.

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    seq2seq.to(target_device)
    return tokenizer, seq2seq, target_device
13
 
14
def generate_explanation(prompt, tokenizer, model, device):
    """Generate a natural-language explanation for *prompt* with a seq2seq model.

    Args:
        prompt: Input text/code to explain.
        tokenizer: Tokenizer paired with *model* (from ``load_model``).
        model: A seq2seq model exposing ``generate``.
        device: "cuda" or "cpu"; inputs are moved here to match the model.

    Returns:
        str: Decoded model output with special tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = model.generate(
            **inputs,
            # NOTE(review): forcing pad_token_id as the decoder start looks like a
            # CodeT5-specific workaround — confirm against the checkpoint config.
            decoder_start_token_id=tokenizer.pad_token_id,
            max_new_tokens=256,
            # BUG FIX: temperature is silently ignored under the default greedy
            # decoding; sampling must be enabled for it to take effect.
            do_sample=True,
            temperature=0.7,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
3
 
4
def load_model():
    """Load the StackOverflow-finetuned CodeBERT classifier and its tokenizer.

    Returns:
        tuple: (tokenizer, model, device) — the model is in eval mode and
        moved to "cuda" when a GPU is available, otherwise "cpu".
    """
    checkpoint = "mrm8488/codebert-base-finetuned-stackoverflow"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    classifier.eval()  # inference only; disables dropout etc.

    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    classifier.to(run_device)
    return tokenizer, classifier, run_device
13
 
14
def generate_explanation(code, tokenizer, model, device):
    """Classify *code* and report the predicted category id.

    Args:
        code: Source-code snippet (str) to classify.
        tokenizer: Tokenizer paired with *model* (from ``load_model``).
        model: A sequence-classification model returning ``.logits``.
        device: "cuda" or "cpu"; inputs are moved here to match the model.

    Returns:
        str: Message embedding the argmax class id. The id→label mapping
        depends on the checkpoint's fine-tuning objective.
    """
    inputs = tokenizer(code, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits
    # Take the argmax over the label dimension explicitly: a bare argmax()
    # flattens the tensor and would break if the batch ever held >1 example.
    predicted_class_id = logits.argmax(dim=-1).item()
    return f"This code is classified as category ID: {predicted_class_id} (label may vary based on fine-tuning objective)"