import gradio as gr
from transformers import GPT2Tokenizer
from model import TransformerModel
import torch
# Load tokenizer (GPT-2 BPE); GPT-2 defines no pad token, so reuse the EOS token
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# Load model
model = TransformerModel(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_layers=12,
    num_heads=16,
    dropout=0.1
)
model.load_state_dict(torch.load("Conv_GPT_finetuned_blended_skill.pth", map_location=torch.device('cpu')))
model.eval()
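# NOTE (assumption): TransformerModel from model.py is expected to take a (batch, seq_len)
# tensor of token ids and return logits of shape (batch, seq_len, vocab_size), with a
# maximum context length of 512 tokens; the decoding loop below relies on this.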
# Define generation function (greedy decoding, one token at a time)
def generate_text(prompt, max_new_tokens=30):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # Keep the prompt within the model's 512-token context window (max_seq_len),
    # leaving room for the tokens about to be generated
    max_prompt_len = 512 - max_new_tokens
    if input_ids.size(1) > max_prompt_len:
        input_ids = input_ids[:, -max_prompt_len:]
    generated_ids = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(generated_ids)  # (batch, seq_len, vocab_size)
            # Greedily pick the most likely next token
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            # Stop at a newline, which ends the assistant's turn in the prompt format
            if tokenizer.decode(next_token.item()) == '\n':
                break
    # Return only the newly generated tokens
    return tokenizer.decode(generated_ids[0, input_ids.size(1):]).strip()
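# Optional alternative to the greedy loop above: temperature + top-k sampling, which
# gives more varied replies. This is a sketch under the same assumptions about the
# model's output shape; it is not wired into the Gradio app, and the function name
# and default parameters are illustrative.
def generate_text_sampled(prompt, max_new_tokens=30, temperature=0.8, top_k=50):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    input_ids = input_ids[:, -(512 - max_new_tokens):]
    generated_ids = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Scale logits by temperature, then restrict to the top_k most likely tokens
            logits = model(generated_ids)[:, -1, :] / temperature
            top_values, top_indices = torch.topk(logits, top_k, dim=-1)
            probs = torch.softmax(top_values, dim=-1)
            # Sample within the top_k set and map back to vocabulary ids
            next_token = top_indices.gather(-1, torch.multinomial(probs, num_samples=1))
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            if tokenizer.decode(next_token.item()) == '\n':
                break
    return tokenizer.decode(generated_ids[0, input_ids.size(1):]).strip()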
# Chat function for Gradio
def chat(message, history):
    # history is provided by gr.ChatInterface but is not used here; each reply
    # is generated from the latest message only
    prompt = f"User: {message}\nAssistant:"
    return generate_text(prompt)
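# A sketch of folding earlier turns into the prompt instead (assuming the default
# tuple-style history of (user, assistant) pairs from gr.ChatInterface):
#
#     def chat_with_history(message, history):
#         turns = "".join(f"User: {u}\nAssistant: {a}\n" for u, a in history)
#         return generate_text(f"{turns}User: {message}\nAssistant:")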
# Create Gradio interface
interface = gr.ChatInterface(
    fn=chat,
    title="Conv GPT 💬",
    description=(
        "Welcome to Conv GPT, a custom transformer-based chatbot trained from scratch! "
        "This model is trained on the DailyDialog dataset and fine-tuned on the "
        "BlendedSkillTalk dataset to provide conversational responses."
    ),
    theme="default",
    examples=["How are you?", "What is your favorite hobby?"]
)
# Launch the app
interface.launch(server_name="0.0.0.0", server_port=7860)