# Captured from a Hugging Face Spaces page; the Space's status banner read
# "Runtime error".
"""Gradio demo that replies to free-text input using a BART seq2seq model.

The user types a message and clicks "Send"; the text generated by the model
is shown in a read-only output box.
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration

MODEL_NAME = "facebook/bart-base"
MAX_LENGTH = 50  # upper bound on the generated sequence length, in tokens

# Load the tokenizer and a BART model *with* a language-modeling head.
# The previous `BartModel` returns only hidden states and has no `generate`.
# NOTE(review): facebook/bart-base is a pretrained checkpoint that is not
# fine-tuned for dialogue, so replies will likely echo or paraphrase the
# input; point MODEL_NAME at a chat/seq2seq fine-tuned model for real answers.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)


def send_message(input_text: str) -> str:
    """Generate a reply to ``input_text``.

    Gradio passes the input Textbox's current value as the argument and writes
    the returned string into the output Textbox; components are never read or
    mutated directly via ``.value``.

    Args:
        input_text: The user's message (may be ``None`` or blank).

    Returns:
        The decoded generated text with special tokens removed, or ``""``
        when the input is empty or whitespace-only.
    """
    if not input_text or not input_text.strip():
        # Nothing to respond to; skip the model call entirely.
        return ""

    # truncation=True keeps very long inputs within the model's max length.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=MAX_LENGTH)

    # generate() returns token IDs (batch of 1 here); decode them to text.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()


# Build the UI. Gradio components must be created inside a Blocks context,
# and they do not accept CSS-style kwargs (fontsize, padding, background, ...);
# visual customization goes through Gradio themes / custom CSS instead.
with gr.Blocks() as demo:
    # Same layout as the original intent: input, button, output side by side.
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(
                value="",
                placeholder="Type something...",
                lines=10,
                label="Message",
            )
        with gr.Column():
            button = gr.Button("Send", variant="primary")
        with gr.Column():
            chatbot_output = gr.Textbox(
                value="",
                label="Reply",
                interactive=False,
            )

    # Wire the click event: the input value goes in, the return value goes out.
    button.click(fn=send_message, inputs=user_input, outputs=chatbot_output)


if __name__ == "__main__":
    demo.launch()