File size: 2,088 Bytes
cdb697b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# app.py
import gradio as gr
import torch
from model import GPTModel  # Import your specific GPT model class
from transformers import PreTrainedTokenizerFast

# Load model and tokenizer once at startup
def load_model_n_tokenizer():
    """Fetch the Nepali GPT-2 model and its BPE tokenizer from the Hugging Face Hub.

    The model is placed on the current CUDA device when CUDA is available,
    otherwise on the CPU. The tokenizer is returned as loaded.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Decide where the weights should live before loading them.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"

    gpt = GPTModel.from_pretrained("Aananda-giri/GPT2-Nepali")
    gpt.to(torch.device(device_name))

    bpe_tokenizer = PreTrainedTokenizerFast.from_pretrained("Aananda-giri/NepaliBPE")
    return gpt, bpe_tokenizer

# Initialize at startup: load the model and tokenizer once, at import time, so
# every call to `generate` below reuses these same module-level objects instead
# of reloading them per request.
model, tokenizer = load_model_n_tokenizer()
# Put the model in evaluation mode. NOTE(review): presumably GPTModel is a
# torch.nn.Module, where eval() disables training-only behavior such as
# dropout — confirm against model.py.
model.eval()

def generate(prompt, max_new_tokens, top_k, temperature, repetition_penalty, penalize_len_below):
    """Generate text from *prompt* using the module-level ``model`` and ``tokenizer``.

    The numeric arguments are the values of the Gradio sliders wired to this
    function; they are coerced to concrete numeric types and forwarded to
    ``model.generate``.

    Args:
        prompt: Text typed into the "Prompt" textbox.
        max_new_tokens: Forwarded as ``max_new_tokens``; coerced to ``int``.
        top_k: Forwarded as ``top_k``; coerced to ``int``.
        temperature: Forwarded as ``temperature``; coerced to ``float``.
        repetition_penalty: Forwarded as ``repetition_penalty``; coerced to ``float``.
        penalize_len_below: Forwarded as ``min_length``; coerced to ``int``.
            NOTE(review): whether ``min_length`` counts the prompt tokens
            depends on ``GPTModel.generate`` — confirm against model.py.

    Returns:
        The first sequence returned by ``model.generate``, decoded with special
        tokens skipped, or ``""`` when the prompt tokenizes to zero tokens.
    """
    # Gradio sliders may hand numeric values over as floats; token counts and
    # top-k must be integers for generation (e.g. loop bounds, topk sizes).
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    penalize_len_below = int(penalize_len_below)
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)

    # Run on whatever device the model's weights were placed on at load time.
    device = next(model.parameters()).device

    with torch.no_grad():
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        # Nothing to condition on (e.g. an empty textbox): an empty sequence
        # would otherwise be fed straight into model.generate.
        if input_ids.shape[-1] == 0:
            return ""

        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            min_length=penalize_len_below,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Build the web UI: one prompt textbox plus one slider per generation setting.
# The input order must match the positional parameters of `generate`.
interface = gr.Interface(
    title="Nepali GPT-2 Text Generator",
    description="Enter Nepali text to generate content using the custom GPT-2 model.",
    fn=generate,
    inputs=[
        # prompt
        gr.Textbox(placeholder="Enter Nepali text here...", label="Prompt"),
        # max_new_tokens
        gr.Slider(label="Max New Tokens", minimum=1, maximum=512, step=1, value=50),
        # top_k
        gr.Slider(label="Top K", minimum=1, maximum=100, step=1, value=3),
        # temperature
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        # repetition_penalty
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2),
        # penalize_len_below
        gr.Slider(label="Minimum Length Penalty", minimum=1, maximum=200, step=1, value=50),
    ],
    outputs=gr.Textbox(label="Generated Text"),
)

# Start the Gradio server (runs at import time, as in a typical app.py).
interface.launch()