|
import gradio as gr |
|
from transformers import GPT2Tokenizer |
|
from model import TransformerModel |
|
import torch |
|
|
|
|
|
# GPT-2 byte-pair tokenizer — the custom model reuses GPT-2's vocabulary
# (vocab_size is read from it when the model is constructed below).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# GPT-2 ships without a pad token; reuse EOS so padding-aware utilities work.
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
|
# Custom from-scratch transformer (see model.py).  These hyperparameters
# must match the ones used at training time or the state_dict won't load.
model = TransformerModel(

    vocab_size=tokenizer.vocab_size,

    hidden_size=512,

    num_layers=12,

    num_heads=16,

    dropout=0.1

)

# Checkpoint fine-tuned on the blended-skill data; map_location='cpu' lets
# the app run on machines without a GPU.
model.load_state_dict(torch.load("Conv_GPT_finetuned_blended_skill.pth", map_location=torch.device('cpu')))

model.eval()  # disable dropout — deterministic inference mode
|
|
|
|
|
|
|
def generate_text(prompt, max_new_tokens=30):
    """Greedily decode up to ``max_new_tokens`` tokens continuing ``prompt``.

    The model's context window is assumed to be 512 tokens (TODO confirm
    against model.py); the context is truncated from the left so generation
    always conditions on the most recent text.  Decoding stops early when a
    newline token is produced, which the training format uses as an
    end-of-turn marker.

    Returns the decoded continuation only (prompt excluded), stripped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Keep the LAST 512 prompt tokens, not the first: the reply has to
    # continue the *end* of the prompt.  (Previously this kept the first
    # 512 and silently dropped the most relevant context.)
    if input_ids.size(1) > 512:
        input_ids = input_ids[:, -512:]

    generated_ids = input_ids
    # Accumulate new token ids separately.  Decoding via a fixed offset
    # into generated_ids is wrong once the sliding window below drops
    # prompt tokens: the offset then no longer marks the prompt boundary.
    new_token_ids = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(generated_ids)
            # Greedy choice of the next token; shape (1, 1) for batch 1.
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
            new_token_ids.append(next_token.item())
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Slide the context window so model input never exceeds 512.
            if generated_ids.size(1) > 512:
                generated_ids = generated_ids[:, -512:]

            # Newline = end of the assistant's turn in the training data.
            if tokenizer.decode(new_token_ids[-1]) == '\n':
                break
    return tokenizer.decode(new_token_ids).strip()
|
|
|
|
|
def chat(message, history):
    """Gradio ChatInterface callback.

    Wraps the user's message in the dialogue template the model was
    trained on and returns the generated reply.  ``history`` is required
    by the ChatInterface signature but unused here — the model only ever
    sees the current turn.
    """
    return generate_text(f"User: {message}\nAssistant:")
|
|
|
|
|
# Build the Gradio chat UI.  ChatInterface renders the turn history;
# the model itself is stateless per turn (see chat()).
interface = gr.ChatInterface(
    fn=chat,
    # Was the mojibake "π¬" — presumably a mis-encoded speech-balloon
    # emoji; restored.
    title="Conv GPT 💬",
    description="""Welcome to Conv GPT, a custom-trained transformer-based chatbot from SCRATCH! This model is trained on the DailyDialog dataset and fine-tuned on the BlendedSkillTalk dataset to
provide conversational responses.""",  # dataset name fixed ("BlendedTalk" -> "BlendedSkillTalk", matching the checkpoint file)
    theme="default",
    # Stray "." after the "?" removed.
    examples=["How are you?", "What is your favorite hobby?"]
)

# Bind to all interfaces on the standard Gradio port so the app is
# reachable from outside a container.
interface.launch(server_name="0.0.0.0", server_port=7860)