# Source: Hugging Face Hub upload by Sarthak-506
# Commit 2cabb35 (verified): "Upload folder using huggingface_hub"
import gradio as gr
from peft import PeftModel
from transformers import RobertaTokenizer, T5ForConditionalGeneration
import torch
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used on CUDA: many CPU kernels have no float16
# implementation (or are very slow in it), so keep float32 on CPU.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model_id = "Salesforce/codet5-base"
new_model_id = "Salesforce/codet5-base-multi-sum"


def _load_model(name):
    """Load a T5 seq2seq checkpoint, move it to ``device`` and set eval mode.

    Args:
        name: Hugging Face Hub model id (or local path) to load.

    Returns:
        The loaded ``T5ForConditionalGeneration`` in eval mode on ``device``.
    """
    model = T5ForConditionalGeneration.from_pretrained(name, torch_dtype=model_dtype)
    # Plain .to(device) instead of device_map avoids requiring `accelerate`.
    model = model.to(device)
    model.eval()
    return model


# The tokenizer is shared by both models. dtype/device are model-loading
# options, not tokenizer options, so none are passed here.
# NOTE(review): assumes codet5-base-multi-sum uses the same vocabulary as
# codet5-base — confirm if outputs look garbled.
tokenizer = RobertaTokenizer.from_pretrained(model_id)
old_model = _load_model(model_id)
# Alternative (kept from an earlier version): load a locally fine-tuned PEFT
# adapter on top of the base model instead of a Hub checkpoint:
#   base_model = _load_model(model_id)
#   fine_tuned_model = PeftModel.from_pretrained(
#       base_model,
#       '/kaggle/input/codet5-fine-tuned/pytorch/v1/1/codet5-finetuned',
#       is_trainable=False,
#   )
fine_tuned_model = _load_model(new_model_id)
def generate_docstring(code, max_new_tokens, model_choice):
    """Generate a docstring for a code snippet with beam search.

    Args:
        code: Source code text to summarize (from the Gradio textbox).
        max_new_tokens: Maximum number of tokens to generate (from the
            slider; cast to int because UI numbers may arrive as floats).
        model_choice: ``"Base Model"`` selects ``old_model``; any other
            value selects ``fine_tuned_model``.

    Returns:
        The decoded text of the best beam with special tokens removed, or
        ``""`` when ``code`` is empty or whitespace-only.
    """
    # Nothing to summarize: skip the model call entirely.
    if not code or not code.strip():
        return ""

    # A single sequence needs no padding; the attention mask would ignore pad
    # tokens anyway, so padding to 512 only wasted encoder compute.
    tokenized_input = tokenizer(
        code,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)

    model_to_use = old_model if model_choice == "Base Model" else fine_tuned_model

    output = model_to_use.generate(
        input_ids=tokenized_input["input_ids"],
        attention_mask=tokenized_input["attention_mask"],
        max_new_tokens=int(max_new_tokens),
        num_beams=5,
        length_penalty=1.0,
        early_stopping=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def _build_demo():
    """Assemble the Gradio Interface that wraps ``generate_docstring``.

    The three input components map positionally onto the function's
    ``code``, ``max_new_tokens`` and ``model_choice`` parameters; the
    dropdown choice strings must match the ``"Base Model"`` check there.
    """
    code_input = gr.Textbox(lines=6, label="Enter Code")
    token_slider = gr.Slider(10, 300, value=100, step=10, label="Max new tokens")
    model_picker = gr.Dropdown(
        label="Model Version",
        choices=["Base Model", "Fine-tuned Model"],
        value="Fine-tuned Model",
    )
    output_box = gr.Text(label="Generated Docstring")

    return gr.Interface(
        fn=generate_docstring,
        inputs=[code_input, token_slider, model_picker],
        outputs=output_box,
        title="🧠 CodeT5: Docstring Generator",
        description=(
            "Select between the base and fine-tuned CodeT5 model to generate "
            "docstrings from code input."
        ),
    )


demo = _build_demo()
# Start the server and request a public Gradio share link.
demo.launch(share=True)