File size: 1,928 Bytes
3821e91
 
 
 
 
 
 
 
749f119
3821e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03a5733
324c700
3821e91
 
 
 
 
 
 
03a5733
 
 
 
 
 
 
 
 
 
3821e91
 
 
 
 
be8fbc5
3821e91
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import gradio as gr
from model import BigramLanguageModel

cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'

# Build the model and load the trained weights.
# NOTE(review): strict=False silently ignores missing/unexpected keys in the
# checkpoint; confirm nanogpt.pth matches BigramLanguageModel and consider
# strict=True so a mismatched checkpoint fails loudly instead of running with
# partially random weights.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = BigramLanguageModel()
model.load_state_dict(torch.load("nanogpt.pth", map_location=torch.device(device)), strict=False)
# map_location only places the loaded tensors; load_state_dict copies them into
# the module's existing parameters, which were created on the CPU. Move the
# module itself so it lives on the same device as the inputs built for it.
model.to(device)
# Inference only: switch off training-time behaviour (e.g. dropout, if the
# model uses it) so sampling uses the trained network as-is.
model.eval()

# Read the text corpus. The vocabulary below is rebuilt from it, so this is
# presumably the same file the model was trained on (token ids must line up
# with the checkpoint) -- confirm.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Collect all the unique characters that occur in the text; sorting makes the
# character -> id assignment deterministic across runs.
chars = sorted(set(text))
vocab_size = len(chars)

# Create mappings between the characters of the corpus and integer token ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode string *s* as a list of integer token ids.

    Raises KeyError if *s* contains a character that does not occur in the
    corpus.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode a list of integer token ids *l* back into a string."""
    return ''.join(itos[i] for i in l)

def inference(input_text, max_new_tokens=200):
    """Continue *input_text* with text sampled from the model.

    Args:
        input_text: Prompt string. Characters that do not occur in the corpus
            have no token id and are dropped instead of raising KeyError.
        max_new_tokens: Number of new tokens (characters) to generate. Gradio
            documents Slider values as floats, so this is cast to int.

    Returns:
        The decoded prompt (after dropping unknown characters) followed by the
        generated continuation. If no usable prompt characters remain,
        generation is seeded with token id 0 (the first character of the
        sorted vocabulary), which then appears at the start of the output.
    """
    # Skip characters outside the vocabulary; encode() would raise KeyError.
    known_text = ''.join(c for c in (input_text or '') if c in stoi)
    token_ids = encode(known_text)
    if not token_ids:
        # generate() presumably needs at least one context token to predict
        # from, so seed an empty prompt with the first vocabulary entry.
        token_ids = [0]

    context = torch.tensor(token_ids, dtype=torch.long, device=device).view(1, -1)

    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=int(max_new_tokens))

    return decode(generated[0].tolist())

# Static text shown at the top of the Gradio page.
title = "NanoGPT trained on Shakespeare Plays dataset"
description = "A simple Gradio interface to generate text from gpt model trained on Shakespeare Plays"

# Clickable examples: one [prompt, max_new_tokens] row per word, matching the
# order of the two input components below.
examples = [
    [word, 200]
    for word in (
        "Shape", "Answer", "Ideology", "Absorb", "Triangle",
        "Listen", "Census", "Balance", "Representative", "Cinema",
    )
]

# Wire the inference function to a prompt textbox plus a length slider, and
# show the generated text in an output textbox.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="Enter any word", type="text"),
        gr.Slider(minimum=100, maximum=2000, step=50, value=200, label="Max Character"),
    ],
    outputs=[gr.Textbox(label="Output", type="text")],
    title=title,
    description=description,
    examples=examples,
)

demo.launch()