import torch
import tiktoken
import gradio as gr
import torch.nn.functional as F

from model import GPT, GPTConfig

# Pick the best available accelerator, falling back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"

model = GPT(GPTConfig())
# strict=False tolerates missing/unexpected keys in the checkpoint.
# NOTE(review): it also silently hides a mismatched checkpoint — confirm which
# keys are expected to be absent before relying on this.
model.load_state_dict(torch.load("nanogpt.pth", map_location=torch.device(device)), strict=False)
model.to(device)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

enc = tiktoken.get_encoding('gpt2')

# Number of highest-probability tokens kept at each sampling step
# (the Hugging Face text-generation pipeline default).
TOP_K = 50


def inference(input_text, num_return_sequences, max_length):
    """Sample continuations of ``input_text`` from the model with top-k sampling.

    Args:
        input_text: Prompt to continue; must encode to at least one token.
        num_return_sequences: Number of independent samples to draw.
        max_length: Total length in tokens (prompt included) of each sample.
            A prompt longer than this is truncated and nothing is generated.

    Returns:
        The decoded samples joined with a ``"\\n======\\n"`` separator.

    Raises:
        gr.Error: If ``input_text`` encodes to no tokens (shown in the UI).
    """
    # Slider values may arrive as floats; repeat/range/slicing need ints.
    num_return_sequences = int(num_return_sequences)
    max_length = int(max_length)

    prompt_tokens = enc.encode(input_text)
    if not prompt_tokens:
        # With zero tokens there is no "last position" to sample from.
        raise gr.Error("Please enter some text to continue.")

    # (B, T): the prompt repeated once per requested sample, on the model's
    # device (not a hard-coded 'cuda', which breaks CPU / MPS hosts).
    x = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
    x = x.unsqueeze(0).repeat(num_return_sequences, 1)

    with torch.no_grad():
        while x.size(1) < max_length:
            # forward the model; element 0 of its output is the logits (B, T, vocab_size)
            logits = model(x)[0]
            # keep only the logits at the last position -> (B, vocab_size)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # top-k sampling; both results are (B, k)
            k = min(TOP_K, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            # multinomial does not require its input to sum to 1
            ix = torch.multinomial(topk_probs, 1)  # (B, 1)
            # map the sampled top-k positions back to vocabulary ids
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            x = torch.cat((x, xcol), dim=1)

    decode_list = [enc.decode(row) for row in x[:, :max_length].tolist()]
    return "\n======\n".join(decode_list)


title = "GPT-2 trained on Shakespeare Plays dataset"
description = "A simple Gradio interface to generate text from GPT-2 model trained on Shakespeare Plays"
examples = [
    ["Please put on these earmuffs because I can't you hear.", 2, 20],
    ["Twin 4-month-olds slept in the shade of the palm tree while the mother tanned in the sun.", 2, 20],
    ["Happiness can be found in the depths of chocolate pudding.", 2, 20],
    ["Seek success, but always be prepared for random cats.", 2, 20],
    ["This made him feel like an old-style rootbeer float smells.", 2, 20],
    ["The view from the lighthouse excited even the most seasoned traveler.", 2, 20],
    ["I've always wanted to go to Tajikistan, but my cat would miss me.", 2, 20],
    ["He found rain fascinating yet unpleasant.", 2, 20],
    ["Plans for this weekend include turning wine into water.", 2, 20],
    ["Iron pyrite is the most foolish of all minerals.", 2, 20],
]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(label="Enter some text", type="text"),
        gr.Slider(minimum=1, maximum=5, step=1, value=5, label="Number of outputs"),
        gr.Slider(minimum=10, maximum=30, step=1, value=20, label="Maximum length of a sequence"),
    ],
    outputs=[
        gr.Textbox(label="Output", type="text"),
    ],
    title=title,
    description=description,
    examples=examples,
)