Spaces:
Runtime error
Runtime error
File size: 2,088 Bytes
cdb697b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# app.py
import gradio as gr
import torch
from model import GPTModel # Import your specific GPT model class
from transformers import PreTrainedTokenizerFast
# Load model and tokenizer once at startup
def load_model_n_tokenizer():
    """Load the pretrained Nepali GPT-2 model and its BPE tokenizer.

    The model weights come from "Aananda-giri/GPT2-Nepali" and are moved to
    the CUDA device when one is available, otherwise to the CPU. The
    tokenizer comes from "Aananda-giri/NepaliBPE".

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Decide on the target device first; it does not depend on the model.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gpt = GPTModel.from_pretrained("Aananda-giri/GPT2-Nepali")
    gpt.to(target_device)

    bpe_tokenizer = PreTrainedTokenizerFast.from_pretrained("Aananda-giri/NepaliBPE")
    return gpt, bpe_tokenizer
# Initialize at startup
# The model/tokenizer pair is created once, at import time, and kept as
# module-level globals; generate() reads both on every request.
model, tokenizer = load_model_n_tokenizer()
# Switch the model to evaluation mode before serving requests.
# NOTE(review): presumably GPTModel follows torch.nn.Module semantics here
# (e.g. dropout disabled) -- confirm against model.py.
model.eval()
def generate(prompt, max_new_tokens, top_k, temperature, repetition_penalty, penalize_len_below):
    """Generate text from ``prompt`` with the module-level model and tokenizer.

    This is the callback wired into the Gradio interface; its positional
    parameters correspond, in order, to the interface's input components.

    Args:
        prompt: Text to tokenize and feed to the model.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``
            (coerced to ``int``).
        top_k: Forwarded as ``top_k`` (coerced to ``int``).
        temperature: Forwarded as ``temperature``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        penalize_len_below: Forwarded as ``min_length`` (coerced to ``int``).

    Returns:
        The first generated sequence decoded to a string, with special
        tokens skipped.
    """
    # Gradio documents that a Slider hands its value to the callback as a
    # float. These three are integer counts, and a float such as 50.0 can make
    # the underlying generation code raise, so normalise them up front.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    penalize_len_below = int(penalize_len_below)

    # Run on whatever device the model's parameters currently live on.
    device = next(model.parameters()).device

    # Inference only: no gradients are needed.
    with torch.no_grad():
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            min_length=penalize_len_below,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Create Gradio interface
# Build the web UI around generate(). The order of `inputs` must match the
# order of generate()'s positional parameters, because Gradio passes the
# component values positionally.
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter Nepali text here..."),
        gr.Slider(minimum=1, maximum=512, value=50, step=1, label="Max New Tokens"),
        gr.Slider(minimum=1, maximum=100, value=3, step=1, label="Top K"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty"),
        gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Minimum Length Penalty"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Nepali GPT-2 Text Generator",
    description="Enter Nepali text to generate content using the custom GPT-2 model.",
)

# Start the Gradio server (the stray trailing "|" that followed this call was
# not valid Python and has been removed).
interface.launch()