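# Gradio demo that generates docstrings from source code, letting the user
# compare the base CodeT5 checkpoint against a fine-tuned variant.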
import gradio as gr
from peft import PeftModel
from transformers import RobertaTokenizer, T5ForConditionalGeneration
import torch
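# Use the GPU when one is available, otherwise fall back to CPU.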
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
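# "codet5-base" is the pretrained base checkpoint; "codet5-base-multi-sum" is a
# CodeT5 checkpoint fine-tuned for code summarization, used here as the
# "Fine-tuned Model" option in the UI.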
model_id = "Salesforce/codet5-base"
new_model_id = "Salesforce/codet5-base-multi-sum"
tokenizer = RobertaTokenizer.from_pretrained(model_id)  # dtype/device options apply to models, not tokenizers
old_model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map=device)
old_model.eval()
# base_model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map=device)
# fine_tuned_model = PeftModel.from_pretrained(base_model, '/kaggle/input/codet5-fine-tuned/pytorch/v1/1/codet5-finetuned', is_trainable=False)
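# The commented lines above show an alternative: loading a locally trained
# PEFT (LoRA) adapter on top of the base model instead of the published
# "codet5-base-multi-sum" checkpoint loaded below.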
fine_tuned_model = T5ForConditionalGeneration.from_pretrained(new_model_id, torch_dtype=torch.float16, device_map=device)
fine_tuned_model.eval()
# Function to generate predictions
def generate_docstring(code, max_new_tokens, model_choice):
    # Tokenize the input code and move the tensors to the model's device.
    tokenized_input = tokenizer(
        code,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(device)
    # Route the request to the selected model.
    if model_choice == "Base Model":
        model_to_use = old_model
    else:
        model_to_use = fine_tuned_model
    # Generate with beam search, without tracking gradients, then decode to text.
    with torch.no_grad():
        output = model_to_use.generate(
            input_ids=tokenized_input['input_ids'],
            attention_mask=tokenized_input['attention_mask'],
            max_new_tokens=max_new_tokens,
            num_beams=5,
            length_penalty=1.0,
            early_stopping=True
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
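# Quick sanity check without the UI (illustrative only; the sample snippet and
# its output are not verified):
# print(generate_docstring("def add(a, b):\n    return a + b", 50, "Fine-tuned Model"))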
# Create the Gradio UI
demo = gr.Interface(
    fn=generate_docstring,
    inputs=[
        gr.Textbox(lines=6, label="Enter Code"),
        gr.Slider(10, 300, value=100, step=10, label="Max new tokens"),
        gr.Dropdown(label="Model Version", choices=["Base Model", "Fine-tuned Model"], value="Fine-tuned Model")
    ],
    outputs=gr.Text(label="Generated Docstring"),
    title="🧠 CodeT5: Docstring Generator",
    description="Select between the base and fine-tuned CodeT5 models to generate docstrings from code input."
)
# Launch with Gradio Sharing Public Link
demo.launch(share=True)