File size: 2,035 Bytes
591ec58
 
 
8d5a3c5
591ec58
 
 
 
c0decd0
591ec58
 
 
 
 
 
 
 
 
1f3d130
c0decd0
a9bd10d
591ec58
 
0beee7d
591ec58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8c75c8
a174693
d8c75c8
591ec58
f527b68
591ec58
 
d8c75c8
591ec58
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
from transformers import GPT2Tokenizer
from model import TransformerModel
import torch

# Tokenizer setup: start from the pretrained "gpt2" tokenizer and reuse its
# end-of-sequence token as the padding token.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token


# Model setup: construct the network, restore the fine-tuned weights onto the
# CPU, then call eval() before serving requests.
_MODEL_KWARGS = dict(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_layers=12,
    num_heads=16,
    dropout=0.1,
)
model = TransformerModel(**_MODEL_KWARGS)
# NOTE(review): torch.load unpickles the checkpoint file, so this path must
# only ever point at a trusted checkpoint.
model.load_state_dict(
    torch.load(
        "Conv_GPT_finetuned_blended_skill.pth",
        map_location=torch.device('cpu'),
    )
)
model.eval()


# Define generation function
def generate_text(prompt, max_new_tokens=30, max_seq_len=512):
    """Greedily generate a continuation of ``prompt`` and return only the new text.

    Uses the module-level ``tokenizer`` and ``model``. At each step the token
    with the highest logit at the last position is appended (greedy decoding).
    Generation stops after ``max_new_tokens`` tokens, when the tokenizer's
    end-of-sequence token is produced, or when a token decoding to exactly
    ``'\\n'`` is produced.

    Args:
        prompt: Text the model is conditioned on.
        max_new_tokens: Upper bound on the number of generated tokens.
        max_seq_len: Maximum number of tokens passed to the model in a single
            forward pass. Defaults to 512, the limit the original code assumed
            for this model.

    Returns:
        The decoded generated tokens (prompt excluded), stripped of leading
        and trailing whitespace.

    Raises:
        ValueError: If ``max_seq_len`` is smaller than 1.
    """
    if max_seq_len < 1:
        raise ValueError("max_seq_len must be at least 1")

    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # Keep the END of an over-long prompt: the most recent text (including the
    # trailing cue the caller appended) is what the model must continue from,
    # and this matches the sliding window applied inside the loop below.
    context = input_ids
    if context.size(1) > max_seq_len:
        context = context[:, -max_seq_len:]

    eos_id = tokenizer.eos_token_id
    # New tokens are collected separately so that sliding the context window
    # can never shift or drop part of the reply.
    new_token_ids = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(context)
            # keepdim=True yields shape (batch, 1), ready for concatenation.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            token_id = next_token.item()
            if eos_id is not None and token_id == eos_id:
                # Stop without emitting the end-of-sequence marker text.
                break
            new_token_ids.append(token_id)
            if tokenizer.decode(token_id) == '\n':
                # A bare newline ends the reply; strip() below removes it.
                break
            # Slide the window so the model never sees more than max_seq_len.
            context = torch.cat([context, next_token], dim=1)[:, -max_seq_len:]
    return tokenizer.decode(new_token_ids).strip()

# Chat function for Gradio
def chat(message, history):
    """Produce the assistant's reply to a single user message.

    Args:
        message: The latest user message.
        history: Earlier conversation turns; accepted for the chat callback
            signature but not used when building the prompt.

    Returns:
        The text returned by ``generate_text`` for the formatted prompt.
    """
    return generate_text("User: {message}\nAssistant:".format(message=message))

# Create Gradio interface: a chat UI whose replies come from `chat`.
interface = gr.ChatInterface(
    fn=chat,
    title="Conv GPT 💬",
    description="""Welcome to Conv GPT, a custom-trained transformer-based chatbot from SCRATCH! This model is trained on the DailyDialog dataset and fine-tuned on the BlendedTalk dataset to
                   provide conversational responses.""",
    theme="default",
    # Example prompts offered in the UI (stray trailing "." removed from the
    # second example).
    examples=["How are you?", "What is your favorite hobby?"],
)


# Launch the app, listening on all network interfaces at port 7860.
interface.launch(server_name="0.0.0.0", server_port=7860)