In [1]:
import torch
import gradio as gr
import random
from config import device_type, ckpt_path, GPTConfig, GPT, encode, decode, ctx, num_samples, max_new_tokens, temperature, top_k

# Rebuild the GPT model from the saved checkpoint and prepare it for inference.
checkpoint = torch.load(ckpt_path, map_location=device_type)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Parameter names carrying the '_orig_mod.' prefix (presumably left behind by
# saving a torch.compile-wrapped model) are renamed in place so they match the
# plain model's keys before loading.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for prefixed_key in [key for key in state_dict if key.startswith(unwanted_prefix)]:
    state_dict[prefixed_key[len(unwanted_prefix):]] = state_dict.pop(prefixed_key)
model.load_state_dict(state_dict)
model.eval()
model.to(device_type)

# Module-level flag; note it is shared by every Gradio session (unlike gr.State).
button_click = False

def fn_query_on_load():
    """Return the default prompt placed in the prompt box when the page loads."""
    default_prompt = "in the air and"
    return default_prompt

# Rebinds the num_samples imported from config: generate_commentary reads this
# global at call time, so every call produces exactly one sample.
num_samples = 1
def generate_commentary(start):
    """Sample ``num_samples`` continuations of ``start`` from the loaded model.

    Each decoded sample is followed by a separator line, and all samples are
    concatenated into a single string.

    Args:
        start: Prompt text; it is tokenized with ``encode`` before sampling.

    Returns:
        The decoded samples, each followed by ``'\\n-o-o-o-o-o-o-o-\\n\\n'``.
    """
    prompt_ids = torch.tensor(encode(start), dtype=torch.long, device=device_type)
    batch = prompt_ids[None, ...]  # prepend a batch dimension of size 1
    separator = '\n-o-o-o-o-o-o-o-\n\n'

    pieces = []
    # Inference only: no autograd graph, and run inside the configured ctx.
    with torch.no_grad(), ctx:
        for _ in range(num_samples):
            tokens = model.generate(batch, max_new_tokens, temperature=temperature, top_k=top_k)
            pieces.append(decode(tokens[0].tolist()))
            pieces.append(separator)

    return ''.join(pieces)
  
 
def fn_gen_comm(prompt, st, o1, o2, o3):
    """Generate one commentary sample and write it into the next output box.

    The three output boxes are filled in rotation, driven by the per-session
    ``stat`` state value ``st``:

    * ``0``: keep box 1, write into box 2, clear box 3, advance to ``1``.
    * ``1``: keep boxes 1 and 2, write into box 3, advance to ``2``.
    * ``-1`` (fresh rotation), ``2`` (all boxes filled) or any other value:
      write into box 1, clear boxes 2 and 3, advance to ``0``.

    Args:
        prompt: Text used to seed the model.
        st: Current rotation position held in the ``stat`` gr.State.
        o1, o2, o3: Current contents of the three output boxes.

    Returns:
        A dict mapping each output component (and ``stat``) to its new value,
        matching the outputs list of the ``submit_btn.click`` wiring.
    """
    # The rotation position lives only in the per-session gr.State. A
    # module-level flag would be shared by every connected browser session,
    # letting one user's clicks disturb another's rotation.
    out = generate_commentary(prompt)

    if st == 0:
        return {output1: o1, output2: out, output3: None, stat: 1}
    if st == 1:
        return {output1: o1, output2: o2, output3: out, stat: 2}

    # -1, 2, or anything unexpected: start a new rotation at box 1 so every
    # Submit click produces visible output and a dict is always returned.
    return {output1: out, output2: None, output3: None, stat: 0}


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # NanoGPT - Cricket Commentary Generative AI
            ### Give a prompt and see how it comes out with cricket commentary :)
            """)

    with gr.Row(visible=True):
        # A callable value is evaluated by Gradio when the page loads.
        search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter prompt..', label='Enter Prompt')

    with gr.Row():
        submit_btn = gr.Button("Submit", variant='primary')
        clear_btn = gr.ClearButton()
    with gr.Row():
        with gr.Column():
            # Successive Submit clicks fill these three boxes in turn.
            output1 = gr.Textbox(lines=10, interactive=False, label='Commentary Box')
            output2 = gr.Textbox(lines=10, interactive=False, label='Commentary Box')
            output3 = gr.Textbox(lines=10, interactive=False, label='Commentary Box')
            # Per-session rotation position passed to fn_gen_comm; -1 starts a new rotation.
            stat = gr.State(value=-1)

    def clear_data():
        """Empty the prompt and all output boxes and restart the rotation at box 1."""
        # fn_gen_comm may also consult the module-level ``button_click`` flag;
        # reset it together with ``stat`` so the next Submit fills box 1
        # instead of resuming mid-rotation next to an empty box 1.
        global button_click
        button_click = False
        return {
            output1: None,
            output2: None,
            output3: None,
            search_text: None,
            stat: -1,
        }

    clear_btn.click(clear_data, None, [output1, output2, output3, search_text, stat])

    submit_btn.click(
        fn_gen_comm,
        [search_text, stat, output1, output2, output3],
        [output1, output2, output3, stat]
    )

# Launch the app.
app.queue().launch()

number of parameters: 29.94M
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


