import torch
import gradio as gr
from transformers import GPT2Tokenizer

from model import TransformerModel

MAX_SEQ_LEN = 512  # model's maximum context length

# Load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load model
model = TransformerModel(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_layers=12,
    num_heads=16,
    dropout=0.1,
)
model.load_state_dict(
    torch.load("Conv_GPT_finetuned_blended_skill.pth", map_location=torch.device("cpu"))
)
model.eval()


# Define generation function (greedy decoding)
def generate_text(prompt, max_new_tokens=30):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Keep the most recent tokens if the prompt exceeds the model's context
    # window; truncating the head preserves the end of the conversation.
    if input_ids.size(1) > MAX_SEQ_LEN:
        input_ids = input_ids[:, -MAX_SEQ_LEN:]

    generated_ids = input_ids
    # Track generated ids separately so mid-loop context truncation cannot
    # corrupt the slice we decode at the end.
    new_tokens = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(generated_ids)
            # Greedy decoding: pick the highest-probability next token
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Truncate the running context if it exceeds the model's window
            if generated_ids.size(1) > MAX_SEQ_LEN:
                generated_ids = generated_ids[:, -MAX_SEQ_LEN:]

            new_tokens.append(next_token.item())

            # Stop at the end of the assistant's line
            if tokenizer.decode([next_token.item()]) == "\n":
                break

    return tokenizer.decode(new_tokens).strip()


# Chat function for Gradio
def chat(message, history):
    prompt = f"User: {message}\nAssistant:"
    return generate_text(prompt)


# Create Gradio interface
interface = gr.ChatInterface(
    fn=chat,
    title="Conv GPT 💬",
    description=(
        "Welcome to Conv GPT, a custom transformer-based chatbot trained from scratch! "
        "The model is trained on the DailyDialog dataset and fine-tuned on the "
        "Blended Skill Talk dataset to produce conversational responses."
    ),
    theme="default",
    examples=["How are you?", "What is your favorite hobby?"],
)

# Launch the app
interface.launch(server_name="0.0.0.0", server_port=7860)
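
# Usage (assuming this script is saved as app.py, with model.py and the
# checkpoint file in the same directory; the app.py filename is illustrative):
#   python app.py
# Gradio then serves the chat UI at http://0.0.0.0:7860.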