# Conv_GPT / app.py
import torch
import gradio as gr
from transformers import GPT2Tokenizer

from model import TransformerModel
# Load the GPT-2 tokenizer; GPT-2 defines no pad token, so reuse EOS for padding
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# Build the custom transformer and load the fine-tuned weights (CPU inference);
# the architecture assumes a maximum sequence length of 512 tokens
model = TransformerModel(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_layers=12,
    num_heads=16,
    dropout=0.1
)
model.load_state_dict(torch.load("Conv_GPT_finetuned_blended_skill.pth", map_location=torch.device('cpu')))
model.eval()
# Greedy decoding: repeatedly append the highest-probability next token
# (a sampling-based alternative is sketched below)
def generate_text(prompt, max_new_tokens=30):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # Keep only the most recent 512 tokens (the model's max_seq_len)
    if input_ids.size(1) > 512:
        input_ids = input_ids[:, -512:]
    generated_ids = input_ids
    new_tokens = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed at most the last 512 tokens so the input never exceeds max_seq_len
            logits = model(generated_ids[:, -512:])
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            new_tokens.append(next_token.item())
            # Stop at the end of the assistant's line
            if tokenizer.decode(next_token.item()) == '\n':
                break
    # Decode only the newly generated tokens, not the prompt
    return tokenizer.decode(new_tokens).strip()
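
# Optional alternative to the greedy argmax above: temperature plus top-k sampling
# often yields less repetitive replies. A minimal sketch under the same 512-token
# assumptions; `temperature` and `top_k` are illustrative, not from the original app.
def sample_text(prompt, max_new_tokens=30, temperature=0.8, top_k=50):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')[:, -512:]
    generated_ids = input_ids
    new_tokens = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Scale logits by temperature, keep the top-k, and sample from the rest
            logits = model(generated_ids[:, -512:])[:, -1, :] / temperature
            topk_logits, topk_ids = torch.topk(logits, top_k, dim=-1)
            probs = torch.softmax(topk_logits, dim=-1)
            next_token = topk_ids.gather(-1, torch.multinomial(probs, num_samples=1))
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            new_tokens.append(next_token.item())
            if tokenizer.decode(next_token.item()) == '\n':
                break
    return tokenizer.decode(new_tokens).strip()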
# Chat handler for Gradio (history is ignored here; a history-aware variant is sketched below)
def chat(message, history):
    prompt = f"User: {message}\nAssistant:"
    return generate_text(prompt)
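
# A minimal sketch of a history-aware variant, assuming Gradio's tuple-style
# history (a list of [user, assistant] pairs) and the same User/Assistant prompt
# format used above; `max_turns` is illustrative. Not wired into the interface.
def chat_with_history(message, history, max_turns=3):
    turns = []
    for user_msg, bot_msg in history[-max_turns:]:
        turns.append(f"User: {user_msg}\nAssistant: {bot_msg}")
    turns.append(f"User: {message}\nAssistant:")
    return generate_text("\n".join(turns))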
# Create the Gradio chat interface
interface = gr.ChatInterface(
    fn=chat,
    title="Conv GPT 💬",
    description=(
        "Welcome to Conv GPT, a custom transformer-based chatbot trained from scratch! "
        "The model is trained on the DailyDialog dataset and fine-tuned on the "
        "Blended Skill Talk dataset to provide conversational responses."
    ),
    theme="default",
    examples=["How are you?", "What is your favorite hobby?"]
)
# Launch the app
interface.launch(server_name="0.0.0.0", server_port=7860)