# Hugging Face Spaces status: Sleeping
import torch
import gradio as gr
import random
from config import (
    device_type,
    ckpt_path,
    GPTConfig,
    GPT,
    encode,
    decode,
    ctx,
    num_samples,
    max_new_tokens,
    temperature,
    top_k,
)

# --- Rebuild the GPT model from the saved checkpoint -------------------------
checkpoint = torch.load(ckpt_path, map_location=device_type)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Parameter names saved under an '_orig_mod.' prefix are renamed in place
# (prefix stripped) so they can be loaded into the freshly built GPT module.
# NOTE(review): the prefix presumably comes from saving a torch.compile()'d
# model — confirm against the training script.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
prefixed_keys = [name for name in state_dict if name.startswith(unwanted_prefix)]
for name in prefixed_keys:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)

model.load_state_dict(state_dict)
model.eval()  # inference only
model.to(device_type)
def fn_query_on_load():
    """Return the default prompt used to pre-fill the prompt textbox."""
    default_prompt = "in the air and"
    return default_prompt
def generate_commentary(start):
    """Generate ``num_samples`` continuations of a prompt with the GPT model.

    Args:
        start: Prompt text from the prompt textbox. May be ``None`` or empty,
            e.g. after the Clear button resets the textbox.

    Returns:
        A dict mapping the module-level ``output`` Textbox component to the
        generated text: each decoded sample followed by a
        ``-o-o-o-o-o-o-o-`` separator line.

    Raises:
        gr.Error: If ``start`` is ``None`` or empty. Gradio shows the
            message to the user instead of failing inside the model.
    """
    if not start:
        # Clear sets the prompt to None, and an empty prompt leaves nothing
        # to encode/condition on; reject both with a readable UI message.
        raise gr.Error("Please enter a prompt.")

    start_ids = encode(start)
    # [None, ...] adds a leading batch dimension of size 1.
    x = torch.tensor(start_ids, dtype=torch.long, device=device_type)[None, ...]

    separator = '\n-o-o-o-o-o-o-o-\n\n'
    samples = []
    with torch.no_grad(), ctx:
        for _ in range(num_samples):
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            samples.append(decode(y[0].tolist()) + separator)

    return {
        output: ''.join(samples)
    }
# --- Gradio UI ----------------------------------------------------------------
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            "\n"
            "# NanoGPT - Cricket Commentary Generative AI\n"
            "### Give a prompt and see how it comes out with cricket commentary :)\n"
        )
    with gr.Row(visible=True):
        # A callable value is evaluated by Gradio to fill in the initial text.
        search_text = gr.Textbox(
            value=fn_query_on_load,
            placeholder='Enter prompt..',
            label='Enter Prompt',
        )
    with gr.Row():
        submit_btn = gr.Button("Submit", variant='primary')
        clear_btn = gr.ClearButton()
    with gr.Row():
        # A single Row is enough; the extra nested Row added nothing.
        output = gr.Textbox(lines=15, interactive=False, label='Commentary Box')

    def clear_data():
        """Reset both the commentary box and the prompt textbox to empty."""
        return {
            output: None,
            search_text: None,
        }

    clear_btn.click(clear_data, None, [output, search_text])
    submit_btn.click(
        generate_commentary,
        search_text,
        output,
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    app.queue().launch()