File size: 812 Bytes
fd3396f
2cfa511
 
 
db97ee5
2cfa511
fd3396f
2cfa511
fd3396f
 
2cfa511
fd3396f
 
e8e6762
 
fd3396f
e8e6762
fd3396f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch

def load_model(model_name="Salesforce/codet5-base-multi-sum"):
    """Load a tokenizer and T5 seq2seq model ready for inference.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the CodeT5 multi-language
            summarization checkpoint this module was written for.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is in eval mode and has
        been moved to CUDA when available, otherwise to the CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # Inference only: switch off dropout and other training-time behavior.
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return tokenizer, model

def generate_explanation(code, tokenizer, model, *, max_new_tokens=150):
    """Generate a natural-language docstring/summary for a code snippet.

    Args:
        code: Source code to describe. Leading and trailing whitespace is
            stripped before it is placed in the prompt.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        model: Model providing ``device`` and ``generate``.
        max_new_tokens: Upper bound on the number of generated tokens
            (keyword-only; defaults to 150).

    Returns:
        The decoded text of the first generated sequence with special
        tokens removed, or ``""`` when ``code`` is empty, whitespace-only,
        or ``None`` (there is nothing to describe, so the model is not run).
    """
    snippet = code.strip() if code else ""
    if not snippet:
        # Avoid running the model on the bare prompt prefix, which can only
        # produce text unrelated to any real code.
        return ""

    # Final prompt style: generate docstring
    prompt = f"generate docstring: {snippet}"
    # truncation=True is passed without an explicit max_length, so the
    # length limit is left to the tokenizer.
    input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True)
    # Inputs must live on the same device as the model weights.
    input_ids = input_ids.to(model.device)
    # NOTE(review): early_stopping is documented as a beam-search option; it
    # is kept so output is unchanged if the model's generation config
    # enables beams -- confirm whether it is still needed.
    generated = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        early_stopping=True,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)