Spaces:
Runtime error
Runtime error
File size: 1,722 Bytes
103a097 c2ba8e6 103a097 c2ba8e6 103a097 c2ba8e6 103a097 c2ba8e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration, BartModel
# Load the BART tokenizer and a BART model that has a language-modelling
# head. The bare ``BartModel`` only returns hidden states and cannot produce
# text, so ``BartForConditionalGeneration`` is used for generation.
# NOTE(review): "facebook/bart-base" is a general pretrained checkpoint, not a
# dialogue model. Replies will likely echo or reconstruct the prompt. Swap in
# a conversational checkpoint if real chat behaviour is wanted.
MODEL_NAME = "facebook/bart-base"
MAX_NEW_TOKENS = 50

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)


def send_message(input_text: str = "") -> str:
    """Generate a reply to ``input_text`` with the BART model.

    Gradio calls this with the current value of the input textbox and writes
    the returned string into the output textbox. Handlers must work through
    arguments and return values; reading or assigning a component's
    ``.value`` at runtime does not reflect what the user typed.

    Args:
        input_text: The user's message. Blank or whitespace-only input
            returns an empty reply without running the model.

    Returns:
        The generated text, decoded with special tokens removed and
        surrounding whitespace stripped.
    """
    if not input_text or not input_text.strip():
        return ""
    # Truncate over-long prompts to the tokenizer's maximum length instead of
    # failing on inputs longer than the model accepts.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    # Text generation is a method of the model, not the tokenizer. The
    # returned token ids are then turned back into a string.
    output_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()


# Build the UI. Components must be created inside a ``gr.Blocks`` context to
# be rendered, and their event listeners are attached in the same context.
# Gradio components do not accept per-widget font/padding/border/background
# kwargs; styling belongs in a theme or custom CSS.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(
                value="",
                placeholder="Type something...",
                lines=10,
                label="Message",
            )
        with gr.Column():
            button = gr.Button("Send", variant="primary")
        with gr.Column():
            chatbot_output = gr.Textbox(
                value="",
                label="Response",
                interactive=False,
            )
    # Pass the textbox value to ``send_message`` and display its return value.
    button.click(fn=send_message, inputs=user_input, outputs=chatbot_output)

# Start the Gradio app.
if __name__ == "__main__":
    demo.launch()