SandeepU commited on
Commit
2cfa511
·
verified ·
1 Parent(s): ee194e2

Upload model_utils.py

Browse files
Files changed (1) hide show
  1. model/model_utils.py +17 -0
model/model_utils.py CHANGED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
4
+ def load_model():
5
+ model_name = "Salesforce/codet5-base"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForCausalLM.from_pretrained(model_name)
8
+ model.eval()
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model.to(device)
12
+ return tokenizer, model, device
13
+
14
+ def generate_explanation(prompt, tokenizer, model, device):
15
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
16
+ output = model.generate(**inputs, max_new_tokens=256, temperature=0.7)
17
+ return tokenizer.decode(output[0], skip_special_tokens=True)