SandeepU committed
Commit db97ee5 · verified · 1 Parent(s): bf2c260

Upload model_utils.py

Files changed (1)
  1. model/model_utils.py +4 -3
model/model_utils.py CHANGED

@@ -2,7 +2,7 @@ from transformers import AutoTokenizer, T5ForConditionalGeneration
 import torch
 
 def load_model():
-    model_name = "Salesforce/codet5-small"
+    model_name = "Salesforce/codet5-base-multi-sum"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = T5ForConditionalGeneration.from_pretrained(model_name)
     model.eval()
@@ -11,7 +11,8 @@ def load_model():
 
 def generate_explanation(code, tokenizer, model):
     device = model.device
-    input_text = "summarize: " + code
+    # Better prompt engineering
+    input_text = f"summarize: This Python function does the following: {code}"
     input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(device)
-    output = model.generate(input_ids, max_new_tokens=150, early_stopping=True)
+    output = model.generate(input_ids, max_new_tokens=200, early_stopping=True)
     return tokenizer.decode(output[0], skip_special_tokens=True)
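For reference, a minimal usage sketch of the two helpers changed in this commit (hypothetical example, not part of the commit; it assumes load_model() returns the (tokenizer, model) pair, since the function's return statement falls outside the diff context shown above):

# Hypothetical usage sketch, not part of this commit.
# Assumes load_model() returns (tokenizer, model); the return statement
# is outside the diff context shown above.
from model.model_utils import load_model, generate_explanation

snippet = "def add(a, b):\n    return a + b"  # sample code to summarize

tokenizer, model = load_model()
print(generate_explanation(snippet, tokenizer, model))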