import torch
import gradio as gr
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from transformer_chat import TransformerChatbot
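
# Load the trained tokenizer from disk and wrap it in a Hugging Face fast
# tokenizer so padding, truncation, and special-token lookups work out of the box.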
tokenizer_obj = Tokenizer.from_file("tokenizer.json")
hf_tok = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_obj,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
)
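
# Rebuild the model with the same hyperparameters used at training time and
# load the ATIS checkpoint; eval() disables dropout for inference.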
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerChatbot(
    vocab_size=hf_tok.vocab_size,
    d_model=512, num_heads=8, d_ff=2048,
    num_encoder_layers=6, num_decoder_layers=6,
    num_roles=2, max_turns=16, num_slots=22,
    dropout=0.1,
).to(device)
model.load_state_dict(torch.load("atis_transformer.pt", map_location=device))
model.eval()
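
# Single-turn chat: encode the prompt once, then sample the reply token by
# token with temperature sampling plus a repetition penalty.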
def chat_fn(prompt):
    # Tokenize the prompt and move it to the model's device.
    enc = hf_tok(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
    src_ids = enc.input_ids.to(device)

    # A single unpadded sequence needs no source padding mask.
    src_mask = None

    # One user utterance: role 0 and turn 0 for every source position.
    roles = torch.zeros_like(src_ids)
    turns = torch.zeros_like(src_ids)

    # Encode once; the decoder attends to enc_out at every generation step.
    with torch.no_grad():
        enc_out = model.encode(src_ids, roles, turns, src_mask)

    # Seed the decoder with [CLS]; [SEP] marks the end of the reply.
    cls_id = hf_tok.cls_token_id
    sep_id = hf_tok.sep_token_id
    dec_input = torch.tensor([[cls_id]], device=device)
    dec_roles = torch.zeros_like(dec_input)
    dec_turns = torch.zeros_like(dec_input)
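
    # Autoregressive decoding loop, capped at 50 tokens.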
    generated = []
    for step in range(50):
        T = dec_input.size(1)

        # Strictly upper-triangular mask: position i may not attend to j > i.
        causal_mask = torch.triu(torch.ones((T, T), device=device), diagonal=1).bool()
        tgt_mask = causal_mask

        # Keep decoding out of autograd as well; this is pure inference.
        with torch.no_grad():
            logits = model.decode(dec_input, enc_out, dec_roles, dec_turns, src_mask, tgt_mask)

        # Only the distribution for the last position matters.
        last_logits = logits[0, -1, :]

        # Repetition penalty: damp tokens that have already been generated.
        # Shrink positive logits and push negative ones further down; a plain
        # *= 0.7 would actually boost repeated tokens with negative logits.
        if generated:
            for token_id in set(generated):
                if last_logits[token_id] > 0:
                    last_logits[token_id] *= 0.7
                else:
                    last_logits[token_id] /= 0.7

        # Temperature < 1 sharpens the distribution before sampling.
        temperature = 0.8
        probs = torch.softmax(last_logits / temperature, dim=-1)
        next_id = torch.multinomial(probs, 1)

        token_text = hf_tok.decode([next_id.item()])
        print(f"Step {step}: Generated token ID {next_id.item()} -> '{token_text}'")

        if next_id.item() == sep_id:
            print("Found SEP token, stopping generation")
            break

        # Append the sampled token (with matching role/turn IDs) and feed the
        # extended sequence back through the decoder on the next step.
        generated.append(next_id.item())
        dec_input = torch.cat([dec_input, next_id.unsqueeze(0)], dim=1)
        dec_roles = torch.cat([dec_roles, torch.zeros_like(next_id).unsqueeze(0)], dim=1)
        dec_turns = torch.cat([dec_turns, torch.zeros_like(next_id).unsqueeze(0)], dim=1)

        # Safety valve: stop if the same token was sampled three times in a row.
        if len(generated) >= 3 and len(set(generated[-3:])) == 1:
            print("Detected repetition loop, stopping generation")
            break

    # skip_special_tokens strips [CLS]/[SEP] anyway, so decode only the
    # generated IDs.
    reply = hf_tok.decode(generated, skip_special_tokens=True)
    return reply


# Minimal Gradio UI: one textbox in, one text answer out.
interface = gr.Interface(
    fn=chat_fn,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
    title="Transformer Chatbot Demo (currently trained on the ATIS dataset)",
    description="Ask flight-related questions and get an answer.",
)

if __name__ == "__main__":
    # share=True also creates a temporary public Gradio link.
    interface.launch(share=True)