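"""Gradio demo for TransformerChatbot: a Transformer encoder-decoder chatbot
trained on the ATIS flight-query dataset. Loads tokenizer.json and the
atis_transformer.pt checkpoint, then serves a simple text-in/text-out chat UI."""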
import torch
import gradio as gr
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from transformer_chat import TransformerChatbot

# Load tokenizer & wrap for HF API
tokenizer_obj = Tokenizer.from_file("tokenizer.json")
hf_tok = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_obj,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]"
)
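# The special-token strings above are assumed to match the vocabulary baked
# into tokenizer.json; any that are missing would be appended as brand-new ids
# the model never saw during training.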

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
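# NOTE: these hyperparameters must match the ones the checkpoint was trained
# with; load_state_dict below fails on any shape mismatch.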
model = TransformerChatbot(
    vocab_size=hf_tok.vocab_size,
    d_model=512, num_heads=8, d_ff=2048,
    num_encoder_layers=6, num_decoder_layers=6,
    num_roles=2, max_turns=16, num_slots=22,
    dropout=0.1
).to(device)
model.load_state_dict(torch.load("atis_transformer.pt", map_location=device))
model.eval()

# Generation function
def chat_fn(prompt):
    # Encode user input
    enc = hf_tok(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
    src_ids  = enc.input_ids.to(device)
    # A single prompt in a batch of one: there are no padding tokens to hide,
    # so no source-side mask is needed for cross-attention
    src_mask = None

    # Roles & turns (user=0)
    roles = torch.zeros_like(src_ids)
    turns = torch.zeros_like(src_ids)

    # Encode
    with torch.no_grad():
        enc_out = model.encode(src_ids, roles, turns, src_mask)

        # Generate reply token-by-token
        cls_id = hf_tok.cls_token_id
        sep_id = hf_tok.sep_token_id
        dec_input = torch.tensor([[cls_id]], device=device)
        dec_roles = torch.zeros_like(dec_input)
        dec_turns = torch.zeros_like(dec_input)

        generated = []
        for step in range(50):
            T = dec_input.size(1)
            # Create causal mask for decoder (upper triangular = masked)
            # PyTorch's MultiheadAttention expects a 2D mask where True = masked
            causal_mask = torch.triu(torch.ones((T, T), device=device), diagonal=1).bool()
            tgt_mask = causal_mask

            logits = model.decode(dec_input, enc_out, dec_roles, dec_turns, src_mask, tgt_mask)
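            # (No KV cache: the decoder re-runs over the whole prefix each
            # step, which is fine at this 50-token budget.)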
            
            # Get the last token's logits
            last_logits = logits[0, -1, :]
            
            # Repetition penalty (CTRL-style): divide positive logits and
            # multiply negative ones, so already-generated tokens always
            # score lower regardless of sign
            if generated:
                penalty = 1.3
                for token_id in set(generated):
                    score = last_logits[token_id]
                    last_logits[token_id] = score / penalty if score > 0 else score * penalty
            
            # Sample with temperature instead of greedy decoding
            temperature = 0.8
            probs = torch.softmax(last_logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, 1)
            
            # Debug: print the token being generated
            token_text = hf_tok.decode([next_id.item()])
            print(f"Step {step}: Generated token ID {next_id.item()} -> '{token_text}'")
            
            if next_id.item() == sep_id:
                print("Found SEP token, stopping generation")
                break
                
            generated.append(next_id.item())
            dec_input = torch.cat([dec_input, next_id.unsqueeze(0)], dim=1)
            dec_roles = torch.cat([dec_roles, torch.zeros_like(next_id).unsqueeze(0)], dim=1)
            dec_turns = torch.cat([dec_turns, torch.zeros_like(next_id).unsqueeze(0)], dim=1)
            
            # Early stopping if we're stuck in a loop
            if len(generated) >= 3 and len(set(generated[-3:])) == 1:
                print("Detected repetition loop, stopping generation")
                break

        # skip_special_tokens strips CLS/SEP anyway, so decode the generated ids directly
        reply = hf_tok.decode(generated, skip_special_tokens=True)

    return reply
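
# Quick smoke test without the UI (assumes tokenizer.json and
# atis_transformer.pt are in the working directory); the query is just an
# illustrative ATIS-style example:
#   print(chat_fn("show me flights from boston to denver"))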

# Build Gradio interface
interface = gr.Interface(
    fn=chat_fn,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
    title="Transformer Chatbot Demo (currently trained with ATIS dataset)",
    description="Ask flight-related questions and get an answer."
)

if __name__ == "__main__":
    interface.launch(share=True)