Kvikontent committed
Commit 1065e57 · 1 Parent(s): d0d1884

Update app.py

Files changed (1)
  1. app.py +13 -63
app.py CHANGED
@@ -1,71 +1,21 @@
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  import torch
- from transformers import AutoTokenizer, BartModel

- # Load the BART model and tokenizer
+ # Load the BART tokenizer and model
  tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
- model = BartModel.from_pretrained("facebook/bart-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

- # Define the user input widget
- user_input = gr.TextInput(
-     value="",
-     placeholder="Type something...",
-     rows=10,
-     cols=50,
-     fontsize=16,
-     padding=10,
-     border=True,
-     background="#f2f2f2"
- )
+ def generate_response(user_input):
+     inputs = tokenizer(user_input, return_tensors="pt", max_length=512, truncation=True)
+     outputs = model.generate(**inputs)
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return response

- # Define the chatbot output widget
- chatbot_output = gr.TextOutput(
-     value="",
-     fontsize=16,
-     padding=10,
-     border=True,
-     background="#ffffff"
- )
+ input_textbox = gr.Textbox(lines=10, label="Enter your text here")
+ output_textbox = gr.Textbox(label="Chatbot Response")

- # Define the button to trigger the chatbot response
- button = gr.Button(
-     label="Send",
-     fontsize=16,
-     padding=10,
-     border=True,
-     background="#4CAF50"
- )
+ title = "Hugging Face BART Chatbot"
+ description = "This chatbot uses the Hugging Face BART model to generate responses based on user input."

- # Define the callback function to call when the button is clicked
- def send_message():
-     # Get the user input text
-     input_text = user_input.value
-
-     # Tokenize the input text
-     inputs = tokenizer(input_text, return_tensors="pt")
-
-     # Run the input through the BART model to get the generated text
-     outputs = model(**inputs)
-     last_hidden_states = outputs.last_hidden_state
-
-     # Extract the generated text from the last hidden state
-     generated_text = tokenizer.generate(last_hidden_states, max_length=50, padding="max_length", truncation=True).strip()
-
-     # Update the chatbot output text
-     chatbot_output.value = generated_text
-
- # Add the widgets to the layout
- gr.layout(
-     gr.Row(
-         gr.Column(user_input),
-         gr.Column(button),
-         gr.Column(chatbot_output)
-     )
- )
-
- # Set up the button click event handler
- button.on_click(send_message)
-
- # Start the Gradio app
- if __name__ == "__main__":
-     gr.run()
+ gr.Interface(fn=generate_response, inputs=input_textbox, outputs=output_textbox, title=title, description=description).launch()
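
For reference, the generate_response() path added in this commit can be exercised outside the Gradio UI with a short standalone script. The sketch below is not part of the commit; it simply mirrors the same transformers calls, and the file name and sample prompt are illustrative. Note that facebook/bart-base is a plain pretrained denoising checkpoint rather than a dialogue-tuned model, so its "responses" tend to reconstruct or lightly paraphrase the input rather than answer it.

# quick_test.py - minimal sketch (not in the commit) for trying the new generation logic locally
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

def generate_response(user_input):
    # Same steps as the updated app.py: tokenize, generate, decode.
    inputs = tokenizer(user_input, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

if __name__ == "__main__":
    print(generate_response("Hello, how are you today?"))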